@@ -13,10 +13,10 @@ const transMS = Stacked(elementwise(identity), elementwise(exp))
1313
1414const int_θdoubleMM = ComponentArrayInterpreter (flatten1 (CA. ComponentVector (; θP, θM)))
1515
16- function f_doubleMM (θ:: AbstractVector , x)
16+ function f_doubleMM (θ:: AbstractVector , x, intθ )
1717 # extract parameters not depending on order, i.e whether they are in θP or θM
1818 y = GPUArraysCore. allowscalar () do
19- θc = int_θdoubleMM (θ)
19+ θc = intθ (θ)
2020 # using ComponentArrays: ComponentArrays as CA
2121 # r0, r1, K1, K2 = θc[(:r0, :r1, :K1, :K2)] # does not work on Zygote+GPU
2222 r0 = θc[:r0 ]
3030
3131function HVI. get_hybridproblem_par_templates (:: DoubleMMCase ; scenario:: NTuple = ())
3232 if (:omit_r0 ∈ scenario)
33+ # return ((; θP = θP_nor0, θM, θf = θP[(:K2r)]))
3334 return ((; θP = θP_nor0, θM))
3435 end
36+ # (; θP, θM, θf = eltype(θP)[])
3537 (; θP, θM)
3638end
3739
@@ -74,11 +76,10 @@ function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::NTuple = ()
7476 )
7577 # fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
7678 par_templates = get_hybridproblem_par_templates (prob; scenario)
77- keys_fixed = ((k for k in keys (θall) if
78- (k ∉ keys (par_templates. θP)) & (k ∉ keys (par_templates. θM))). .. ,)
79- let θFix = gdev (θall[keys_fixed])
79+ intθ, θFix = setup_PBMpar_interpreter (par_templates. θP, par_templates. θM, θall)
80+ let θFix = gdev (θFix), intθ = get_concrete (intθ)
8081 function f_doubleMM_with_global (θP:: AbstractVector , θMs:: AbstractMatrix , x)
81- pred_sites = applyf (f_doubleMM, θMs, θP, θFix, x)
82+ pred_sites = applyf (f_doubleMM, θMs, θP, θFix, x, intθ )
8283 pred_global = eltype (pred_sites)[]
8384 return pred_global, pred_sites
8485 end
@@ -101,7 +102,12 @@ const xP_S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0]
101102# const xP_S2 = Float32[1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0]
102103
103104HVI. get_hybridproblem_n_covar (prob:: DoubleMMCase ; scenario) = 5
104- HVI. get_hybridproblem_n_site (prob:: DoubleMMCase ; scenario) = 800
105+ function HVI. get_hybridproblem_n_site (prob:: DoubleMMCase ; scenario)
106+ if (:few_sites ∈ scenario)
107+ return (100 )
108+ end
109+ 800
110+ end
105111
106112function HVI. get_hybridproblem_train_dataloader (prob:: DoubleMMCase ; scenario = (),
107113 n_batch, rng:: AbstractRNG = StableRNG (111 ), kwargs...
0 commit comments