Skip to content

Commit 0697a44

Browse files
committed
make PoiintSoilver use Problems logdensity function
1 parent 98831cd commit 0697a44

File tree

7 files changed

+20
-15
lines changed

7 files changed

+20
-15
lines changed

dev/doubleMM.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ end
137137
() -> begin # optimized loss is indeed lower than with true parameters
138138
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
139139
ϕg = 1:length(prob0.ϕg), θP = prob0.θP))
140-
loss_gf = get_loss_gf(prob0.g, prob0.transM, prob0.transP, prob0.f, Float32[], int_ϕθP)
140+
loss_gf = get_loss_gf(prob0.g, prob0.transM, prob0.transP, prob0.f, Float32[], py, int_ϕθP)
141141
loss_gf(vcat(prob3.ϕg, prob3.θP), xM, xP, y_o, y_unc, i_sites)[1]
142142
loss_gf(vcat(prob3o.ϕg, prob3o.θP), xM, xP, y_o, y_unc, i_sites)[1]
143143
#

src/HybridSolver.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,14 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPointSolve
3232
train_loader_dev = train_loader
3333
end
3434
f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites=false)
35+
py = get_hybridproblem_neg_logden_obs(prob; scenario)
3536
pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario)
3637
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario)
3738
priors = get_hybridproblem_priors(prob; scenario)
3839
priorsP = [priors[k] for k in keys(par_templates.θP)]
3940
priorsM = [priors[k] for k in keys(par_templates.θM)]
4041
#intP = ComponentArrayInterpreter(par_templates.θP)
41-
loss_gf = get_loss_gf(g_dev, transM, transP, f, intϕ;
42+
loss_gf = get_loss_gf(g_dev, transM, transP, f, py, intϕ;
4243
cdev=infer_cdev(gdevs), pbm_covars, n_site_batch=n_batch, priorsP, priorsM,)
4344
# call loss function once
4445
l1 = is_infer ?

src/gf.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ Create a loss function for given
137137
- g(x, ϕ): machine learning model
138138
- transM: transforamtion of parameters at unconstrained space
139139
- f(θMs, θP): mechanistic model
140+
- py: `function(y_pred, y_obs, y_unc)` to compute negative log-likelihood, i.e. cost
140141
- intϕ: interpreter attaching axis with components ϕg and ϕP
141142
- intP: interpreter attaching axis to ζP = ϕP with components used by f,
142143
The default, uses `intϕ(ϕ)` as a template
@@ -160,7 +161,7 @@ and returns a NamedTuple of
160161
- `neg_log_prior`: negative log-prior of `θMs` and `θP`
161162
- `neg_log_prior`: negative log-prior of `θMs` and `θP`
162163
"""
163-
function get_loss_gf(g, transM, transP, f,
164+
function get_loss_gf(g, transM, transP, f, py,
164165
intϕ::AbstractComponentArrayInterpreter,
165166
intP::AbstractComponentArrayInterpreter = ComponentArrayInterpreter(
166167
intϕ(1:length(intϕ)).ϕP);
@@ -178,7 +179,6 @@ function get_loss_gf(g, transM, transP, f,
178179
#, intP = get_concrete(intP)
179180
#inv_transP = inverse(transP), kwargs = kwargs
180181
function loss_gf(ϕ, xM, xP, y_o, y_unc, i_sites)
181-
σ = exp.(y_unc ./ 2)
182182
ϕc = intϕ(ϕ)
183183
# μ_ζP = ϕc.ϕP
184184
# xMP = _append_each_covars(xM, CA.getdata(μ_ζP), pbm_covar_indices)
@@ -190,7 +190,9 @@ function get_loss_gf(g, transM, transP, f,
190190
y_pred, θMs_pred, θP_pred = gf(
191191
g, transMs, transP, f, xM, xP, CA.getdata(ϕc.ϕg), CA.getdata(ϕc.ϕP),
192192
pbm_covar_indices; cdev, kwargs...)
193-
nLy = sum(abs2, (y_pred .- y_o) ./ σ)
193+
#σ = exp.(y_unc ./ 2)
194+
#nLy = sum(abs2, (y_pred .- y_o) ./ σ)
195+
nLy = py( y_pred, y_o, y_unc)
194196
# logpdf is not typestable for Distribution{Univariate, Continuous}
195197
logpdf_t = (prior, θ) -> logpdf(prior, θ)::eltype(θP_pred)
196198
logpdf_tv = (prior, θ::AbstractVector) -> begin

src/logden_normal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
neg_logden_indep_normal(obs, μ, logσ2s; σfac=1.0)
33
4-
Compute the negative Log-density of `θM` for multiple independent normal distributions,
4+
Compute the negative Log-density of `obs` for multiple independent normal distributions,
55
given estimated means `μ` and estimated log of variance parameters `logσ2s`.
66
77
All the arguments should be vectors of the same length.

test/test_HybridProblem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ test_without_flux = (scenario) -> begin
141141
train_loader = get_hybridproblem_train_dataloader(prob; scenario)
142142
(xM, xP, y_o, y_unc, i_sites) = first(train_loader)
143143
f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = false)
144+
py = get_hybridproblem_neg_logden_obs(prob; scenario)
144145
par_templates = get_hybridproblem_par_templates(prob; scenario)
145146
#f(par_templates.θP, hcat(par_templates.θM, par_templates.θM), xP[1:2])
146147
(; transM, transP) = get_hybridproblem_transforms(prob; scenario)
@@ -154,7 +155,7 @@ test_without_flux = (scenario) -> begin
154155
p = p0 = vcat(ϕg0, par_templates.θP .* convert(eltype(ϕg0), 0.8))
155156

156157
# Pass the site-data for the batches as separate vectors wrapped in a tuple
157-
loss_gf = get_loss_gf(g, transM, transP, f, intϕ;
158+
loss_gf = get_loss_gf(g, transM, transP, f, py, intϕ;
158159
pbm_covars, n_site_batch = n_batch, priorsP, priorsM)
159160
(_xM, _xP, _y_o, _y_unc, _i_sites) = first(train_loader)
160161
l1 = loss_gf(p0, _xM, _xP, _y_o, _y_unc, _i_sites)

test/test_doubleMM.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ end
199199
n_site, n_site_batch = get_hybridproblem_n_site_and_batch(prob; scenario)
200200
f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = false)
201201
f2 = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = true)
202+
py = get_hybridproblem_neg_logden_obs(prob; scenario)
202203
priors = get_hybridproblem_priors(prob; scenario)
203204
priorsP = [priors[k] for k in keys(par_templates.θP)]
204205
priorsM = [priors[k] for k in keys(par_templates.θM)]
@@ -218,9 +219,9 @@ end
218219
pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario)
219220

220221
#loss_gf = get_loss_gf(g, transM, f, intϕ; gdev = identity)
221-
loss_gf = get_loss_gf(g, transM, transP, f, intϕ;
222+
loss_gf = get_loss_gf(g, transM, transP, f, py, intϕ;
222223
pbm_covars, n_site_batch = n_batch, priorsP, priorsM)
223-
loss_gf2 = get_loss_gf(g, transM, transP, f2, intϕ;
224+
loss_gf_site = get_loss_gf(g, transM, transP, f2, py, intϕ;
224225
pbm_covars, n_site_batch = n_site, priorsP, priorsM)
225226
nLjoint = @inferred first(loss_gf(p0, first(train_loader)...))
226227
(xM_batch, xP_batch, y_o_batch, y_unc_batch, i_sites_batch) = first(train_loader)
@@ -237,7 +238,7 @@ end
237238
#optprob, Adam(0.02), callback = callback_loss(100), maxiters = 5000);
238239
optprob, Adam(0.02), maxiters = 2000)
239240

240-
(;nLjoint_pen, y_pred, θMs_pred, θP_pred, nLy, neg_log_prior, loss_penalty) = loss_gf2(
241+
(;nLjoint_pen, y_pred, θMs_pred, θP_pred, nLy, neg_log_prior, loss_penalty) = loss_gf_site(
241242
res.u, train_loader.data...)
242243
#(nLjoint, y_pred, θMs_pred, θP, nLy, neg_log_prior, loss_penalty) = loss_gf(p0, xM, xP, y_o, y_unc);
243244
θMs_pred = CA.ComponentArray(θMs_pred, CA.getaxes(θMs_true'))

test/test_util_gpu.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ gdev = gpu_device()
1515
if gdev isa MLDataDevices.CUDADevice
1616
@testset "ones_similar_x" begin
1717
B = CUDA.rand(Float32, 5, 2); # GPU matrix
18-
@test HVI.ones_similar_x(B, size(B,1)) isa CuArray
19-
@test HVI.ones_similar_x(ComponentVector(b=B), size(B,1)) isa CuArray
20-
@test HVI.ones_similar_x(B', size(B,1)) isa CuArray
21-
@test HVI.ones_similar_x(@view(B[:,2]), size(B,1)) isa CuArray
22-
@test HVI.ones_similar_x(ComponentVector(b=B)[:,1], size(B,1)) isa CuArray
18+
@test HVI.ones_similar_x(B, size(B,1)) isa CUDA.CuArray
19+
@test HVI.ones_similar_x(ComponentVector(b=B), size(B,1)) isa CUDA.CuArray
20+
@test HVI.ones_similar_x(B', size(B,1)) isa CUDA.CuArray
21+
@test HVI.ones_similar_x(@view(B[:,2]), size(B,1)) isa CUDA.CuArray
22+
@test HVI.ones_similar_x(ComponentVector(b=B)[:,1], size(B,1)) isa CUDA.CuArray
2323
end
2424
end
2525

0 commit comments

Comments
 (0)