@@ -6,6 +6,11 @@ const θall = vcat(θP, θM)
66
77const θP_nor0 = θP[(:K2 ,)]
88
9+ const xP_S1 = Float32[0.5 , 0.5 , 0.5 , 0.5 , 0.4 , 0.3 , 0.2 , 0.1 ]
10+ const xP_S2 = Float32[1.0 , 3.0 , 4.0 , 5.0 , 5.0 , 5.0 , 5.0 , 5.0 ]
11+
12+ int_xP1 = ComponentArrayInterpreter (CA. ComponentVector (S1= xP_S1, S2= xP_S2))
13+
914# const transP = elementwise(exp)
1015# const transM = elementwise(exp)
1116
@@ -164,20 +169,29 @@ function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::NTuple = ()
164169 intθ1, θFix1 = setup_PBMpar_interpreter (par_templates. θP, par_templates. θM, θall)
165170 θFix = repeat (θFix1' , n_site_batch)
166171 intθ = get_concrete (ComponentArrayInterpreter ((n_site_batch,), intθ1))
172+ # int_xPb = ComponentArrayInterpreter((n_site_batch,), int_xP1)
167173 isP = repeat (axes (par_templates. θP,1 )' , n_site_batch)
168- let θFix = θFix, θFix_dev = gdev (θFix), intθ = get_concrete (intθ), isP= isP, n_site_batch= n_site_batch
174+ let θFix = θFix, θFix_dev = gdev (θFix), intθ = get_concrete (intθ), isP= isP,
175+ n_site_batch= n_site_batch,
176+ # int_xPb=get_concrete(int_xPb),
177+ pos_xP = get_positions (int_xP1)
169178 function f_doubleMM_with_global (θP:: AbstractVector , θMs:: AbstractMatrix , xP)
170- @assert length (xP) == n_site_batch
179+ @assert size (xP, 2 ) == n_site_batch
171180 @assert size (θMs,2 ) == n_site_batch
172- # convert vector of tuples to tuple of matricesByRows
173- # need to supply xP as vectorOfTuples to work with DataLoader
174- # k = first(keys(xP[1]))
175- xPM = (; zip (keys (xP[1 ]), map (keys (xP[1 ])) do k
176- # stack(map(r -> r[k], xP))'
177- stack (map (r -> r[k], xP); dims = 1 )
178- end )... )
181+ # # convert vector of tuples to tuple of matricesByRows
182+ # # need to supply xP as vectorOfTuples to work with DataLoader
183+ # # k = first(keys(xP[1]))
184+ # xPM = (; zip(keys(xP[1]), map(keys(xP[1])) do k
185+ # #stack(map(r -> r[k], xP))'
186+ # stack(map(r -> r[k], xP); dims = 1)
187+ # end)...)
179188 # xPM = map(transpose, xPM1)
189+ # xPc = int_xPb(CA.getdata(xP))
190+ # xPM = (S1 = xPc[:,:S1], S2 = xPc[:,:S2]) # problems with Zygote
180191 # make sure the same order of columns as in intθ
192+ # reshape big matrix into NamedTuple of drivers S1 and S2
193+ # for broadcasting need sites in rows
194+ xPM = map (p -> CA. getdata (xP[p,:])' , pos_xP)
181195 θFixd = (θP isa GPUArraysCore. AbstractGPUVector) ? θFix_dev : θFix
182196 θ = hcat (CA. getdata (θP[isP]), CA. getdata (θMs)' , θFixd)
183197 pred_sites = f_doubleMM (θ, xPM, intθ)'
202216# return Float32
203217# end
204218
205- const xP_S1 = Float32[0.5 , 0.5 , 0.5 , 0.5 , 0.4 , 0.3 , 0.2 , 0.1 ]
206- const xP_S2 = Float32[1.0 , 3.0 , 4.0 , 5.0 , 5.0 , 5.0 , 5.0 , 5.0 ]
207219
208220# two observations more?
209221# const xP_S1 = Float32[0.5, 0.5, 0.5, 0.5, 0.5, 0.4, 0.3, 0.1]
@@ -242,7 +254,10 @@ function HVI.gen_hybridproblem_synthetic(rng::AbstractRNG, prob::DoubleMMCase;
242254 # normalize to be distributed around the prescribed true values
243255 θMs_true = int_θMs_sites (scale_centered_at (θMs_true0, θM, FloatType (0.1 )))
244256 f = get_hybridproblem_PBmodel (prob; scenario, gdev= identity, use_all_sites = true )
245- xP = fill ((; S1 = xP_S1, S2 = xP_S2), n_site)
257+ # xP = fill((; S1 = xP_S1, S2 = xP_S2), n_site)
258+ int_xPn = ComponentArrayInterpreter (int_xP1, (n_site,))
259+ xP = int_xPn (vcat (repeat (xP_S1,1 ,n_site),repeat (xP_S2,1 ,n_site)))
260+ # xP[:S1,:]
246261 θP = par_templates. θP
247262 y_global_true, y_true = f (θP, θMs_true, xP)
248263 σ_o = FloatType (0.01 )
0 commit comments