@@ -49,8 +49,8 @@ test_scenario = (scenario) -> begin
4949 @test _y_o == y_o
5050 end ; tmpf()
5151
52- # TODO
53- # g, ϕg0 = @inferred get_hybridproblem_MLapplicator(probc; scenario);
52+ # prediction by g(ϕg, XM) does not correspond to θMs_true, randomly initialized
53+ # only the magnitude is there because of NormalScaling and prior
5454 g, ϕg0 = get_hybridproblem_MLapplicator(probc; scenario)
5555 f = get_hybridproblem_PBmodel(probc; scenario, use_all_sites= false )
5656 f_pred = get_hybridproblem_PBmodel(probc; scenario, use_all_sites= true )
@@ -123,21 +123,21 @@ test_scenario = (scenario) -> begin
123123 ϕunc_true = copy(probc. ϕunc)
124124 sd_ζP_true = [0.2 ,20 ]
125125 sd_ζMs_a_true = [0.1 ,2 ] # sd at_variance at θ==0
126- logσ2_ζMs_b_true = [- 3.0 ,+ 0.2 ] # slope of log_variance with θ
126+ logσ2_ζMs_b_true = [- 0.3 ,+ 0.2 ] # slope of log_variance with θ
127127 ρsP_true = [+ 0.8 ]
128128 ρsM_true = [- 0.6 ]
129129
130130 ϕunc_true. logσ2_ζP = (log ∘ abs2). (sd_ζP_true)
131131 ϕunc_true. coef_logσ2_ζMs[1 ,:] = (log ∘ abs2). (sd_ζMs_a_true)
132- # ϕunc_true.coef_logσ2_ζMs[2,:] = logσ2_ζMs_b_true # predicted means do not scale
132+ ϕunc_true. coef_logσ2_ζMs[2 ,:] = logσ2_ζMs_b_true
133133 ϕunc_true. ρsP = ρsP_true
134134 ϕunc_true. ρsM = ρsM_true
135135
136136 probd = CP. update(probc; ϕunc= ϕunc_true);
137137 _ϕ = vcat(ϕ_ini. μP, probc. ϕg, probd. ϕunc)
138138 # hcat(ϕ_ini, ϕ, _ϕ)[1:4,:]
139139 # hcat(ϕ_ini, ϕ, _ϕ)[(end-20):end,:]
140- n_predict = 8000
140+ n_predict = 80000
141141 xM_batch = xM[:, 1 : n_batch]
142142 _ζsP, _ζsMs, _σ = @inferred (
143143 # @descend_code_warntype (
@@ -146,30 +146,37 @@ test_scenario = (scenario) -> begin
146146 n_MC = n_predict, cor_ends, pbm_covar_indices,
147147 int_unc= interpreters. unc, int_μP_ϕg_unc= interpreters. μP_ϕg_unc)
148148 )
149- meanζMs_true = g(xM_batch, probc. ϕg)' # have been generated with no scaling
150- function test_distζ(_ζsP, _ζsMs, ϕunc_true, meanζMs_true )
149+ ζMs_g = g(xM_batch, probc. ϕg)' # have been generated with no scaling
150+ function test_distζ(_ζsP, _ζsMs, ϕunc_true, ζMs_g )
151151 mP = mean(_ζsP; dims= 2 )
152152 residP = _ζsP .- mP
153153 sdP = vec(std(residP; dims= 2 ))
154154 _sd_ζP_true = sqrt.(exp.(ϕunc_true. logσ2_ζP))
155155 @test isapprox(sdP, _sd_ζP_true; rtol= 0.05 )
156156 mMs = mean(_ζsMs; dims= 3 )[:,:,1 ]
157- # hcat(mMs, meanζMs_true)
157+ hcat(mMs, ζMs_g)
158+ # @usingany UnicodePlots
159+ # scatterplot(ζMs_g[:,1], mMs[:,1])
160+ # scatterplot(ζMs_g[:,2], mMs[:,2])
161+ @test cor(ζMs_g[:,1 ], mMs[:,1 ]) > 0.9
162+ @test cor(ζMs_g[:,2 ], mMs[:,2 ]) > 0.8
158163 map(axes(mMs,2 )) do ipar
159164 # @show ipar
160- @test isapprox(mMs[:,ipar], meanζMs_true [:,ipar]; rtol= 0.1 )
165+ @test isapprox(mMs[:,ipar], ζMs_g [:,ipar]; rtol= 0.1 )
161166 end
162167 # ζMs_true = stack(map(inverse(transM), eachcol(CA.getdata(θMs_true[:,1:n_batch]))))'
163168 residMs = _ζsMs .- mMs
164169 sdMs = std(residMs; dims= 3 )[:,:,1 ]
165170 # (_a,_b), mMi = first(zip(
166171 # eachcol(ϕunc_true.coef_logσ2_ζMs), eachcol(mMs)))
167172 _sd_ζMs_true = stack(map(
168- eachcol(ϕunc_true. coef_logσ2_ζMs), eachcol(mMs)) do (_a,_b), mMi
173+ eachcol(ϕunc_true. coef_logσ2_ζMs), eachcol(ζMs_g)) do (_a,_b), mMi
174+ # eachcol(ϕunc_true.coef_logσ2_ζMs), eachcol(mMs)) do (_a,_b), mMi
169175 logσ2_ζM = _a .+ mMi .* _b
170176 sqrt.(exp.(logσ2_ζM))
171177 end )
172178 # ipar = 2
179+ # ipar = 1
173180 map(axes(sdMs,2 )) do ipar
174181 # @show ipar
175182 hcat(sdMs[:,ipar], _sd_ζMs_true[:,ipar])
@@ -184,17 +191,17 @@ test_scenario = (scenario) -> begin
184191 cor_PMs = cor(residPMst' )
185192 @test cor_PMs[1,2] ≈ ρsP_true[1] atol=0.2
186193 @test all(.≈(cor_PMs[1:2,3:end], 0.0, atol=0.2)) # no correlations P,M
187- @test cor_PMs[3,4] ≈ ρsM_true[1] atol=0.1
194+ @test cor_PMs[3,4] ≈ ρsM_true[1] atol=0.2
188195 @test all(.≈(cor_PMs[3:4,5:end], 0.0, atol=0.2)) # no correlations M1, M2
189- @test cor_PMs[5,6] ≈ ρsM_true[1] atol=0.1
196+ @test cor_PMs[5,6] ≈ ρsM_true[1] atol=0.2
190197 @test all(.≈(cor_PMs[5:6,7:end], 0.0, atol=0.2)) # no correlations M1, M2
191198 end
192- test_distζ(_ζsP, _ζsMs, ϕunc_true, meanζMs_true )
199+ test_distζ(_ζsP, _ζsMs, ϕunc_true, ζMs_g )
193200 @testset "predict_hvi check sd" begin
194201 # test if uncertainty and reshaping is propagated
195202 # here inverse the predicted θs and then test distribution
196203 probcu = CP.update(probc, ϕunc=ϕunc_true);
197- n_sample_pred = 8000
204+ n_sample_pred = 24_000
198205 (; y, θsP, θsMs, entropy_ζ) = predict_hvi(rng, probcu; scenario, n_sample_pred);
199206 #size(_ζsMs), size(θsMs)
200207 #size(_ζsP), size(θsP)
@@ -204,8 +211,8 @@ test_scenario = (scenario) -> begin
204211 _ζsMs2 = stack(map(eachslice(θsMs; dims=3)) do _θMs
205212 int_minvM(_θMs)
206213 end)
207- meanζMs_true2 = g(xM, probcu.ϕg)' # have been generated with no scaling
208- test_distζ(_ζsP2, _ζsMs2, ϕunc_true, meanζMs_true2 )
214+ ζMs_g2 = g(xM, probcu.ϕg)' # have been generated with no scaling
215+ test_distζ(_ζsP2, _ζsMs2, ϕunc_true, ζMs_g2 )
209216 end ;
210217 end ;
211218 end # if covar in scenario
0 commit comments