Skip to content

Commit cb53223

Browse files
committed
test scaling variance
update HybridProblem for narrower types
1 parent 34924cb commit cb53223

File tree

2 files changed

+38
-31
lines changed

2 files changed

+38
-31
lines changed

src/HybridProblem.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
struct HybridProblem <: AbstractHybridProblem
2-
θP::Any
3-
θM::Any
2+
θP::CA.ComponentVector
3+
θM::CA.ComponentVector
44
f_batch::Any
55
f_allsites::Any
6-
g::Any
7-
ϕg::Any
8-
ϕunc::Any
9-
priors::Any
10-
py::Any
11-
transM::Any
12-
transP::Any
13-
cor_ends::Any # = (P=(1,),M=(1,))
14-
train_dataloader::Any
6+
g::AbstractModelApplicator
7+
ϕg::Any # depends on framework
8+
ϕunc::CA.ComponentVector
9+
priors::AbstractDict
10+
py::Any # any callable
11+
transM::Stacked
12+
transP::Stacked
13+
cor_ends::@NamedTuple{P::Vector{Int}, M::Vector{Int}} # = (P=(1,),M=(1,))
14+
train_dataloader::MLUtils.DataLoader
1515
n_covar::Int
1616
n_site::Int
1717
n_batch::Int
18-
pbm_covars::NTuple
18+
pbm_covars::NTuple{_N, Symbol} where _N
1919
#inner constructor to constrain the types
2020
function HybridProblem(
2121
θP::CA.ComponentVector, θM::CA.ComponentVector,
@@ -24,9 +24,9 @@ struct HybridProblem <: AbstractHybridProblem
2424
f_batch::Function,
2525
f_allsites::Function,
2626
priors::AbstractDict,
27-
py::Function,
28-
transM::Union{Function, Bijectors.Transform},
29-
transP::Union{Function, Bijectors.Transform},
27+
py,
28+
transM::Stacked,
29+
transP::Stacked,
3030
# return a function that constructs the trainloader based on n_batch
3131
train_dataloader::MLUtils.DataLoader,
3232
n_covar::Int,

test/test_elbo.jl

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ test_scenario = (scenario) -> begin
4949
@test _y_o == y_o
5050
end; tmpf()
5151

52-
# TODO
53-
#g, ϕg0 = @inferred get_hybridproblem_MLapplicator(probc; scenario);
52+
# prediction by g(ϕg, XM) does not correspond to θMs_true, randomly initialized
53+
# only the magnitude is there because of NormalScaling and prior
5454
g, ϕg0 = get_hybridproblem_MLapplicator(probc; scenario)
5555
f = get_hybridproblem_PBmodel(probc; scenario, use_all_sites=false)
5656
f_pred = get_hybridproblem_PBmodel(probc; scenario, use_all_sites=true)
@@ -123,21 +123,21 @@ test_scenario = (scenario) -> begin
123123
ϕunc_true = copy(probc.ϕunc)
124124
sd_ζP_true = [0.2,20]
125125
sd_ζMs_a_true = [0.1,2] # sd at_variance at θ==0
126-
logσ2_ζMs_b_true = [-3.0,+0.2] # slope of log_variance with θ
126+
logσ2_ζMs_b_true = [-0.3,+0.2] # slope of log_variance with θ
127127
ρsP_true = [+0.8]
128128
ρsM_true = [-0.6]
129129

130130
ϕunc_true.logσ2_ζP = (log abs2).(sd_ζP_true)
131131
ϕunc_true.coef_logσ2_ζMs[1,:] = (log abs2).(sd_ζMs_a_true)
132-
#ϕunc_true.coef_logσ2_ζMs[2,:] = logσ2_ζMs_b_true # predicted means do not scale
132+
ϕunc_true.coef_logσ2_ζMs[2,:] = logσ2_ζMs_b_true
133133
ϕunc_true.ρsP = ρsP_true
134134
ϕunc_true.ρsM = ρsM_true
135135

136136
probd = CP.update(probc; ϕunc=ϕunc_true);
137137
= vcat(ϕ_ini.μP, probc.ϕg, probd.ϕunc)
138138
#hcat(ϕ_ini, ϕ, _ϕ)[1:4,:]
139139
#hcat(ϕ_ini, ϕ, _ϕ)[(end-20):end,:]
140-
n_predict = 8000
140+
n_predict = 80000
141141
xM_batch = xM[:, 1:n_batch]
142142
_ζsP, _ζsMs, _σ = @inferred (
143143
# @descend_code_warntype (
@@ -146,30 +146,37 @@ test_scenario = (scenario) -> begin
146146
n_MC = n_predict, cor_ends, pbm_covar_indices,
147147
int_unc=interpreters.unc, int_μP_ϕg_unc=interpreters.μP_ϕg_unc)
148148
)
149-
meanζMs_true = g(xM_batch, probc.ϕg)' # have been generated with no scaling
150-
function test_distζ(_ζsP, _ζsMs, ϕunc_true, meanζMs_true)
149+
ζMs_g = g(xM_batch, probc.ϕg)' # have been generated with no scaling
150+
function test_distζ(_ζsP, _ζsMs, ϕunc_true, ζMs_g)
151151
mP = mean(_ζsP; dims=2)
152152
residP = _ζsP .- mP
153153
sdP = vec(std(residP; dims=2))
154154
_sd_ζP_true = sqrt.(exp.(ϕunc_true.logσ2_ζP))
155155
@test isapprox(sdP, _sd_ζP_true; rtol=0.05)
156156
mMs = mean(_ζsMs; dims=3)[:,:,1]
157-
#hcat(mMs, meanζMs_true)
157+
hcat(mMs, ζMs_g)
158+
# @usingany UnicodePlots
159+
#scatterplot(ζMs_g[:,1], mMs[:,1])
160+
#scatterplot(ζMs_g[:,2], mMs[:,2])
161+
@test cor(ζMs_g[:,1], mMs[:,1]) > 0.9
162+
@test cor(ζMs_g[:,2], mMs[:,2]) > 0.8
158163
map(axes(mMs,2)) do ipar
159164
#@show ipar
160-
@test isapprox(mMs[:,ipar], meanζMs_true[:,ipar]; rtol=0.1)
165+
@test isapprox(mMs[:,ipar], ζMs_g[:,ipar]; rtol=0.1)
161166
end
162167
#ζMs_true = stack(map(inverse(transM), eachcol(CA.getdata(θMs_true[:,1:n_batch]))))'
163168
residMs = _ζsMs .- mMs
164169
sdMs = std(residMs; dims=3)[:,:,1]
165170
# (_a,_b), mMi = first(zip(
166171
# eachcol(ϕunc_true.coef_logσ2_ζMs), eachcol(mMs)))
167172
_sd_ζMs_true = stack(map(
168-
eachcol(ϕunc_true.coef_logσ2_ζMs), eachcol(mMs)) do (_a,_b), mMi
173+
eachcol(ϕunc_true.coef_logσ2_ζMs), eachcol(ζMs_g)) do (_a,_b), mMi
174+
#eachcol(ϕunc_true.coef_logσ2_ζMs), eachcol(mMs)) do (_a,_b), mMi
169175
logσ2_ζM = _a .+ mMi .* _b
170176
sqrt.(exp.(logσ2_ζM))
171177
end)
172178
#ipar = 2
179+
#ipar = 1
173180
map(axes(sdMs,2)) do ipar
174181
#@show ipar
175182
hcat(sdMs[:,ipar], _sd_ζMs_true[:,ipar])
@@ -184,17 +191,17 @@ test_scenario = (scenario) -> begin
184191
cor_PMs = cor(residPMst')
185192
@test cor_PMs[1,2] ≈ ρsP_true[1] atol=0.2
186193
@test all(.≈(cor_PMs[1:2,3:end], 0.0, atol=0.2)) # no correlations P,M
187-
@test cor_PMs[3,4] ≈ ρsM_true[1] atol=0.1
194+
@test cor_PMs[3,4] ≈ ρsM_true[1] atol=0.2
188195
@test all(.≈(cor_PMs[3:4,5:end], 0.0, atol=0.2)) # no correlations M1, M2
189-
@test cor_PMs[5,6] ≈ ρsM_true[1] atol=0.1
196+
@test cor_PMs[5,6] ≈ ρsM_true[1] atol=0.2
190197
@test all(.≈(cor_PMs[5:6,7:end], 0.0, atol=0.2)) # no correlations M1, M2
191198
end
192-
test_distζ(_ζsP, _ζsMs, ϕunc_true, meanζMs_true)
199+
test_distζ(_ζsP, _ζsMs, ϕunc_true, ζMs_g)
193200
@testset "predict_hvi check sd" begin
194201
# test if uncertainty and reshaping is propagated
195202
# here inverse the predicted θs and then test distribution
196203
probcu = CP.update(probc, ϕunc=ϕunc_true);
197-
n_sample_pred = 8000
204+
n_sample_pred = 24_000
198205
(; y, θsP, θsMs, entropy_ζ) = predict_hvi(rng, probcu; scenario, n_sample_pred);
199206
#size(_ζsMs), size(θsMs)
200207
#size(_ζsP), size(θsP)
@@ -204,8 +211,8 @@ test_scenario = (scenario) -> begin
204211
_ζsMs2 = stack(map(eachslice(θsMs; dims=3)) do _θMs
205212
int_minvM(_θMs)
206213
end)
207-
meanζMs_true2 = g(xM, probcu.ϕg)' # have been generated with no scaling
208-
test_distζ(_ζsP2, _ζsMs2, ϕunc_true, meanζMs_true2)
214+
ζMs_g2 = g(xM, probcu.ϕg)' # have been generated with no scaling
215+
test_distζ(_ζsP2, _ζsMs2, ϕunc_true, ζMs_g2)
209216
end;
210217
end;
211218
end # if covar in scenario

0 commit comments

Comments
 (0)