@@ -181,55 +181,43 @@ end
181181# end
182182# end
183183
184- function HVI. get_hybridproblem_PBmodel (prob:: DoubleMMCase ; scenario:: Val{scen} ,
185- use_all_sites = false ,
186- gdev = :f_on_gpu ∈ HVI. _val_value (scenario) ? gpu_device () : identity
187- ) where {scen}
188- n_site, n_batch = get_hybridproblem_n_site_and_batch (prob; scenario)
189- n_site_batch = use_all_sites ? n_site : n_batch
190- # fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
191- par_templates = get_hybridproblem_par_templates (prob; scenario)
192- intθ1, θFix1 = setup_PBMpar_interpreter (par_templates. θP, par_templates. θM, θall)
193- θFix = repeat (θFix1' , n_site_batch)
194- intθ = get_concrete (ComponentArrayInterpreter ((n_site_batch,), intθ1))
195- # int_xPb = ComponentArrayInterpreter((n_site_batch,), int_xP1)
196- isP = repeat (axes (par_templates. θP, 1 )' , n_site_batch)
197- let θFix = θFix, θFix_dev = gdev (θFix), intθ = get_concrete (intθ), isP = isP,
198- n_site_batch = n_site_batch,
199- # int_xPb=get_concrete(int_xPb),
200- pos_xP = get_positions (int_xP1)
184+ # defining the PBmodel as a closure with let leads to problems of JLD2 reloading
185+ # Define all the variables additional to the ones passed curing the call by
186+ # a dedicated Closure object and define the PBmodel as a callable
187+ struct DoubleMMCaller{CLT}
188+ cl:: CLT
189+ end
201190
202- function f_doubleMM_with_global (θP:: AbstractVector , θMs:: AbstractMatrix , xP)
203- @assert size (xP, 2 ) == n_site_batch
204- @assert size (θMs, 1 ) == n_site_batch
205- # # convert vector of tuples to tuple of matricesByRows
206- # # need to supply xP as vectorOfTuples to work with DataLoader
207- # # k = first(keys(xP[1]))
208- # xPM = (; zip(keys(xP[1]), map(keys(xP[1])) do k
209- # #stack(map(r -> r[k], xP))'
210- # stack(map(r -> r[k], xP); dims = 1)
211- # end)...)
212- # xPM = map(transpose, xPM1)
213- # xPc = int_xPb(CA.getdata(xP))
214- # xPM = (S1 = xPc[:,:S1], S2 = xPc[:,:S2]) # problems with Zygote
215- # make sure the same order of columns as in intθ
216- # reshape big matrix into NamedTuple of drivers S1 and S2
217- # for broadcasting need sites in rows
218- # xPM = map(p -> CA.getdata(xP[p,:])', pos_xP)
219- xPM = map (p -> CA. getdata (xP)' [:, p], pos_xP)
220- θFixd = (θP isa GPUArraysCore. AbstractGPUVector) ? θFix_dev : θFix
221- θ = hcat (CA. getdata (θP[isP]), CA. getdata (θMs), θFixd)
222- pred_sites = f_doubleMM (θ, xPM; intθ)'
223- pred_global = eltype (pred_sites)[]
224- return pred_global, pred_sites
225- end
226- # function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, xP)
227- # # TODO
228- # pred_sites = f_doubleMM(θMs, θP, θFix, xP, intθ)
229- # pred_global = eltype(pred_sites)[]
230- # return pred_global, pred_sites
231- # end
232- end
191+ function HVI. get_hybridproblem_PBmodel (prob:: DoubleMMCase ; scenario, kwargs... )
192+ # θall defined in this module above
193+ cl = HVI. PBmodelClosure (prob; scenario, θall, int_xP1, kwargs... )
194+ return DoubleMMCaller {typeof(cl)} (cl)
195+ end
196+
197+ function (caller:: DoubleMMCaller )(θP:: AbstractVector , θMs:: AbstractMatrix , xP)
198+ cl = caller. cl
199+ @assert size (xP, 2 ) == cl. n_site_batch
200+ @assert size (θMs, 1 ) == cl. n_site_batch
201+ # # convert vector of tuples to tuple of matricesByRows
202+ # # need to supply xP as vectorOfTuples to work with DataLoader
203+ # # k = first(keys(xP[1]))
204+ # xPM = (; zip(keys(xP[1]), map(keys(xP[1])) do k
205+ # #stack(map(r -> r[k], xP))'
206+ # stack(map(r -> r[k], xP); dims = 1)
207+ # end)...)
208+ # xPM = map(transpose, xPM1)
209+ # xPc = int_xPb(CA.getdata(xP))
210+ # xPM = (S1 = xPc[:,:S1], S2 = xPc[:,:S2]) # problems with Zygote
211+ # make sure the same order of columns as in intθ
212+ # reshape big matrix into NamedTuple of drivers S1 and S2
213+ # for broadcasting need sites in rows
214+ # xPM = map(p -> CA.getdata(xP[p,:])', pos_xP)
215+ xPM = map (p -> CA. getdata (xP)' [:, p], cl. pos_xP)
216+ θFixd = (θP isa GPUArraysCore. AbstractGPUVector) ? cl. θFix_dev : cl. θFix
217+ θ = hcat (CA. getdata (θP[cl. isP]), CA. getdata (θMs), θFixd)
218+ pred_sites = f_doubleMM (θ, xPM; cl. intθ)'
219+ pred_global = eltype (pred_sites)[]
220+ return pred_global, pred_sites
233221end
234222
235223function HVI. get_hybridproblem_neg_logden_obs (:: DoubleMMCase ; scenario:: Val )
@@ -284,8 +272,7 @@ function HVI.gen_hybridproblem_synthetic(rng::AbstractRNG, prob::DoubleMMCase;
284272 xP = int_xP_sites (vcat (repeat (xP_S1, 1 , n_site), repeat (xP_S2, 1 , n_site)))
285273 # xP[:S1,:]
286274 θP = par_templates. θP
287- # θint = ComponentArrayInterpreter( (size(θMs_true,2),), CA.getaxes(vcat(θP, θMs_true[:,1])))
288- y_global_true, y_true = f (θP, θMs_true' , xP)
275+ y_global_true, y_true = f (θP, θMs_true' , xP)
289276 σ_o = FloatType (0.01 )
290277 # σ_o = FloatType(0.002)
291278 logσ2_o = FloatType (2 ) .* log .(σ_o)
0 commit comments