Skip to content

Commit 841fecb

Browse files
committed
return updated problem from solvers
1 parent f09e13c commit 841fecb

File tree

3 files changed

+34
-10
lines changed

3 files changed

+34
-10
lines changed

src/HybridSolver.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPointSolve
4848
Optimization.AutoZygote())
4949
optprob = OptimizationProblem(optf, CA.getdata(ϕ0_dev), train_loader)
5050
res = Optimization.solve(optprob, solver.alg; kwargs...)
51-
(; ϕ = intϕ(res.u), resopt = res)
51+
ϕ = intϕ(res.u)
52+
θP = cpu_ca(apply_preserve_axes(transP, cpu_ca(ϕ).ϕP))
53+
probo = update(prob; ϕg = cpu_ca(ϕ).ϕg, θP)
54+
(; ϕ, resopt = res, probo)
5255
end
5356

5457
struct HybridPosteriorSolver{A} <: AbstractHybridSolver
@@ -107,7 +110,9 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorS
107110
optprob = OptimizationProblem(optf, CA.getdata(ϕ0_dev), train_loader)
108111
res = Optimization.solve(optprob, solver.alg; kwargs...)
109112
ϕc = interpreters.μP_ϕg_unc(res.u)
110-
(; ϕ = ϕc, θP = cpu_ca(apply_preserve_axes(transP, ϕc.μP)), resopt = res, interpreters)
113+
θP = cpu_ca(apply_preserve_axes(transP, ϕc.μP))
114+
probo = update(prob; ϕg = cpu_ca(ϕ).ϕg, θP = θP, ϕunc = cpu_ca(ϕ).unc);
115+
(; ϕ = ϕc, θP, resopt = res, interpreters, probo)
111116
end
112117

113118
function fit_narrow_normal(θi, prior, θmean_quant)

src/gf.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,18 @@ function gf(prob::AbstractHybridProblem, xM, xP, args...;
2929
(; θP, θM) = get_hybridproblem_par_templates(prob; scenario)
3030
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
3131
intP = ComponentArrayInterpreter(θP)
32-
pbm_covar_indices = intP(1:length(intP))[pbm_covars]
32+
pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario)
33+
pbm_covar_indices = CA.getdata(intP(1:length(intP))[pbm_covars])
3334
ζP = inverse(transP)(θP)
3435
g_dev, ϕg_dev, ζP_dev = (gdev(g), gdev(ϕg), gdev(CA.getdata(ζP)))
35-
gf(g_dev, transM, transP, f, xM, xP, ϕg_dev, ζP_dev; cdev, pbm_covar_indices, kwargs...)
36+
gf(g_dev, transM, transP, f, xM, xP, ϕg_dev, ζP_dev, pbm_covar_indices; cdev, kwargs...)
3637
end
3738

3839
function gf(g, transM, transP, f, xM, xP, ϕg, ζP;
3940
cdev = identity, pbm_covars,
40-
intP = ComponentArrayInterpreter(ζP))
41+
intP = ComponentArrayInterpreter(ζP), kwargs...)
4142
pbm_covar_indices = intP(1:length(intP))[pbm_covars]
42-
gf(g, transM, transP, f, xM, xP, ϕg, ζP, pbm_covar_indices)
43+
gf(g, transM, transP, f, xM, xP, ϕg, ζP, pbm_covar_indices; kwargs...)
4344
end
4445

4546

@@ -52,7 +53,6 @@ function gf(g, transM, transP, f, xM, xP, ϕg, ζP, pbm_covar_indices::AbstractV
5253
# # otherwise Zyote fails on cpu_handler
5354
# ζP = copy(ζP)
5455
# end
55-
#Main.@infiltrate_main
5656
#xMP = _append_PBM_covars(xM, intP(ζP), pbm_covars)
5757
xMP = _append_each_covars(xM, CA.getdata(ζP), pbm_covar_indices)
5858
θMs = gtrans(g, transM, xMP, ϕg; cdev)
@@ -99,7 +99,7 @@ function get_loss_gf(g, transM, transP, f, y_o_global,
9999

100100
let g = g, transM = transM, transP = transP, f = f, y_o_global = y_o_global,
101101
intϕ = get_concrete(intϕ),
102-
pbm_covar_indices = intP(1:length(intP))[pbm_covars]
102+
pbm_covar_indices = CA.getdata(intP(1:length(intP))[pbm_covars])
103103
#, intP = get_concrete(intP)
104104
#inv_transP = inverse(transP), kwargs = kwargs
105105
function loss_gf(p, xM, xP, y_o, y_unc, i_sites)

test/test_HybridProblem.jl

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,9 @@ test_with_flux = (scenario) -> begin
172172
rng = StableRNG(111)
173173
solver = HybridPosteriorSolver(; alg=Adam(0.02), n_batch=11, n_MC=3)
174174
(; ϕ, θP, resopt) = solve(prob, solver; scenario, rng,
175-
callback = callback_loss(100), maxiters = 1200,
175+
#callback = callback_loss(100), maxiters = 1200,
176176
#maxiters = 20 # too small so that it yields error
177-
#maxiters=200,
177+
maxiters=37,
178178
θmean_quant = 0.01, # test constraining mean to initial prediction
179179
gdev = identity
180180
)
@@ -199,6 +199,25 @@ test_with_flux = (scenario) -> begin
199199
);
200200
@test CA.getdata(ϕ) isa GPUArraysCore.AbstractGPUVector
201201
@test cdev.unc.ρsM)[1] > 0
202+
#
203+
solver = HybridPosteriorSolver(; alg=Adam(0.02), n_batch=11, n_MC=3)
204+
test_correlation = () -> begin
205+
n_epoch = 100 # requires
206+
(; ϕ, θP, resopt, probo) = solve(prob, solver; scenario = scenf,
207+
maxiters = n_batches_in_epoch * n_epoch,
208+
callback = callback_loss(n_batches_in_epoch*5)
209+
);
210+
(; θ, y, entropy_ζ) = predict_gf(rng, probo; scenario = scenf, n_sample_pred = 400);
211+
mean_θ = CA.ComponentVector(mean(CA.getdata(θ); dims = 2)[:, 1], CA.getaxes(θ[:, 1])[1])
212+
residθ = θ .- mean_θ
213+
cr = cor(CA.getdata(residθ));
214+
i_sites = [1,2,3]
215+
tmp = CA.ComponentArray(collect(axes(θ[:,1],1)), CA.getaxes(θ[:,1]));
216+
#ax = map(x -> axes(x,1), get_hybridproblem_par_templates(probo; scenario = scenf))
217+
is = vcat(tmp.P, vec(tmp.Ms[:,i_sites]))
218+
cr[is,is]
219+
end
220+
202221
end;
203222
# does not work with general Bijector:
204223
# @testset "HybridPosteriorSolver also f on gpu" begin

0 commit comments

Comments
 (0)