Skip to content

Commit cc7a7b2

Browse files
authored
rename AbstractHybridCase to AbstractHybridProblem (#15)
1 parent 357f1ec commit cc7a7b2

14 files changed

+131
-131
lines changed

dev/doubleMM.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,24 @@ using OptimizationOptimisers
1616
using Bijectors
1717
using UnicodePlots
1818

19-
const case = DoubleMM.DoubleMMCase()
19+
const prob = DoubleMM.DoubleMMCase()
2020
scenario = (:default,)
2121
rng = StableRNG(111)
2222

23-
par_templates = get_hybridcase_par_templates(case; scenario)
23+
par_templates = get_hybridproblem_par_templates(prob; scenario)
2424

25-
#n_covar = get_hybridcase_n_covar(case; scenario)
26-
#, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
25+
#n_covar = get_hybridproblem_n_covar(prob; scenario)
26+
#, n_batch, n_θM, n_θP) = get_hybridproblem_sizes(prob; scenario)
2727

2828
(; xM, n_site, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
29-
) = gen_hybridcase_synthetic(rng, case; scenario);
29+
) = gen_hybridcase_synthetic(rng, prob; scenario);
3030

3131
n_covar = size(xM,1)
3232

3333

3434
#----- fit g to θMs_true
35-
g, ϕg0 = get_hybridcase_MLapplicator(case; scenario);
36-
(; transP, transM) = get_hybridcase_transforms(case; scenario)
35+
g, ϕg0 = get_hybridproblem_MLapplicator(prob; scenario);
36+
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
3737

3838
function loss_g(ϕg, x, g, transM)
3939
ζMs = g(x, ϕg) # predict the log of the parameters
@@ -52,8 +52,8 @@ res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), max
5252
l1, θMs_pred = loss_g(ϕg_opt1, xM, g, transM)
5353
scatterplot(vec(θMs_true), vec(θMs_pred))
5454

55-
f = get_hybridcase_PBmodel(case; scenario)
56-
py = get_hybridcase_neg_logden_obs(case; scenario)
55+
f = get_hybridproblem_PBmodel(prob; scenario)
56+
py = get_hybridproblem_neg_logden_obs(prob; scenario)
5757

5858
#----------- fit g and θP to y_o
5959
() -> begin
@@ -85,8 +85,8 @@ end
8585

8686
#---------- HVI
8787
n_MC = 3
88-
(; transP, transM) = get_hybridcase_transforms(case; scenario)
89-
FT = get_hybridcase_float_type(case; scenario)
88+
(; transP, transM) = get_hybridproblem_transforms(prob; scenario)
89+
FT = get_hybridproblem_float_type(prob; scenario)
9090

9191
(; ϕ, transPMs_batch, interpreters, get_transPMs, get_ca_int_PMs) = init_hybrid_params(
9292
θP_true, θMs_true[:, 1], ϕg_opt1, n_batch; transP, transM);
@@ -167,7 +167,7 @@ mean_σ_o_MC = 0.006042
167167
ϕ = CA.getdata(ϕ_ini) |> Flux.gpu;
168168
xM_gpu = xM |> Flux.gpu;
169169
scenario_flux = (scenario..., :use_Flux)
170-
g_flux, _ = get_hybridcase_MLapplicator(case; scenario = scenario_flux);
170+
g_flux, _ = get_hybridproblem_MLapplicator(prob; scenario = scenario_flux);
171171

172172
# otpimize using LUX
173173
() -> begin
@@ -205,7 +205,7 @@ gr = Zygote.gradient(fcost,
205205
gr_c = CA.ComponentArray(gr[1] |> Flux.cpu, CA.getaxes(ϕ_ini)...)
206206

207207
train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch)
208-
#train_loader = get_hybridcase_train_dataloader(case, rng; scenario = (scenario..., :use_Flux))
208+
#train_loader = get_hybridproblem_train_dataloader(prob, rng; scenario = (scenario..., :use_Flux))
209209

210210
optf = Optimization.OptimizationFunction(
211211
(ϕ, data) -> begin

ext/HybridVariationalInferenceFluxExt.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ end
3636
# end
3737

3838
function HVI.construct_3layer_MLApplicator(
39-
rng::AbstractRNG, case::HVI.AbstractHybridCase, ::Val{:Flux};
39+
rng::AbstractRNG, prob::HVI.AbstractHybridProblem, ::Val{:Flux};
4040
scenario::NTuple = ())
41-
(;θM) = get_hybridcase_par_templates(case; scenario)
41+
(;θM) = get_hybridproblem_par_templates(prob; scenario)
4242
n_out = length(θM)
43-
n_covar = get_hybridcase_n_covar(case; scenario)
44-
#(; n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
45-
float_type = get_hybridcase_float_type(case; scenario)
43+
n_covar = get_hybridproblem_n_covar(prob; scenario)
44+
#(; n_covar, n_θM) = get_hybridproblem_sizes(prob; scenario)
45+
float_type = get_hybridproblem_float_type(prob; scenario)
4646
is_using_dropout = :use_dropout scenario
4747
is_using_dropout && error("dropout scenario not supported with Flux yet.")
4848
g_chain = Flux.Chain(

ext/HybridVariationalInferenceSimpleChainsExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ end
2020
HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ)
2121

2222
function HVI.construct_3layer_MLApplicator(
23-
rng::AbstractRNG, case::HVI.AbstractHybridCase, ::Val{:SimpleChains};
23+
rng::AbstractRNG, prob::HVI.AbstractHybridProblem, ::Val{:SimpleChains};
2424
scenario::NTuple = ())
25-
n_covar = get_hybridcase_n_covar(case; scenario)
26-
FloatType = get_hybridcase_float_type(case; scenario)
27-
(;θM) = get_hybridcase_par_templates(case; scenario)
25+
n_covar = get_hybridproblem_n_covar(prob; scenario)
26+
FloatType = get_hybridproblem_float_type(prob; scenario)
27+
(;θM) = get_hybridproblem_par_templates(prob; scenario)
2828
n_out = length(θM)
2929
is_using_dropout = :use_dropout scenario
3030
g_chain = if is_using_dropout
Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,40 +2,40 @@
22
Type to dispatch constructing data and network structures
33
for different cases of hybrid problem setups
44
5-
For a specific case, provide functions that specify details
6-
- `get_hybridcase_MLapplicator`
7-
- `get_hybridcase_PBmodel`
8-
- `get_hybridcase_neg_logden_obs`
9-
- `get_hybridcase_par_templates`
10-
- `get_hybridcase_transforms`
11-
- `get_hybridcase_train_dataloader` (default depends on `gen_hybridcase_synthetic`)
5+
For a specific prob, provide functions that specify details
6+
- `get_hybridproblem_MLapplicator`
7+
- `get_hybridproblem_PBmodel`
8+
- `get_hybridproblem_neg_logden_obs`
9+
- `get_hybridproblem_par_templates`
10+
- `get_hybridproblem_transforms`
11+
- `get_hybridproblem_train_dataloader` (default depends on `gen_hybridcase_synthetic`)
1212
optionally
1313
- `gen_hybridcase_synthetic`
14-
- `get_hybridcase_n_covar` (defaults to number of rows in xM in train_dataloader )
15-
- `get_hybridcase_float_type` (defaults to `eltype(θM)`)
16-
- `get_hybridcase_cor_starts` (defaults to include all correlations: `(P=(1,), M=(1,))`)
14+
- `get_hybridproblem_n_covar` (defaults to number of rows in xM in train_dataloader )
15+
- `get_hybridproblem_float_type` (defaults to `eltype(θM)`)
16+
- `get_hybridproblem_cor_starts` (defaults to include all correlations: `(P=(1,), M=(1,))`)
1717
"""
18-
abstract type AbstractHybridCase end;
18+
abstract type AbstractHybridProblem end;
1919

2020

2121
"""
22-
get_hybridcase_MLapplicator([rng::AbstractRNG,] ::AbstractHybridCase; scenario=())
22+
get_hybridproblem_MLapplicator([rng::AbstractRNG,] ::AbstractHybridProblem; scenario=())
2323
24-
Construct the machine learning model fro given problem case and ML-Framework and
24+
Construct the machine learning model fro given problem prob and ML-Framework and
2525
scenario.
2626
2727
returns a Tuple of
2828
- AbstractModelApplicator
2929
- initial parameter vector
3030
"""
31-
function get_hybridcase_MLapplicator end
31+
function get_hybridproblem_MLapplicator end
3232

33-
function get_hybridcase_MLapplicator(case::AbstractHybridCase; scenario=())
34-
get_hybridcase_MLapplicator(Random.default_rng(), case; scenario)
33+
function get_hybridproblem_MLapplicator(prob::AbstractHybridProblem; scenario=())
34+
get_hybridproblem_MLapplicator(Random.default_rng(), prob; scenario)
3535
end
3636

3737
"""
38-
get_hybridcase_PBmodel(::AbstractHybridCase; scenario::NTuple=())
38+
get_hybridproblem_PBmodel(::AbstractHybridProblem; scenario::NTuple=())
3939
4040
Construct the process-based model function
4141
`f(θP::AbstractVector, θMs::AbstractMatrix, x) -> (AbstractVector, AbstractMatrix)`
@@ -48,59 +48,59 @@ returns a tuple of predictions with components
4848
- first, those that are constant across sites
4949
- second, those that vary across sites, with a column for each site
5050
"""
51-
function get_hybridcase_PBmodel end
51+
function get_hybridproblem_PBmodel end
5252

5353
"""
54-
get_hybridcase_neg_logden_obs(::AbstractHybridCase; scenario)
54+
get_hybridproblem_neg_logden_obs(::AbstractHybridProblem; scenario)
5555
5656
Provide a `function(y_obs, ypred) -> Real` that computes the negative logdensity
5757
of the observations, given the predictions.
5858
"""
59-
function get_hybridcase_neg_logden_obs end
59+
function get_hybridproblem_neg_logden_obs end
6060

6161

6262
"""
63-
get_hybridcase_par_templates(::AbstractHybridCase; scenario)
63+
get_hybridproblem_par_templates(::AbstractHybridProblem; scenario)
6464
6565
Provide tuple of templates of ComponentVectors `θP` and `θM`.
6666
"""
67-
function get_hybridcase_par_templates end
67+
function get_hybridproblem_par_templates end
6868

6969

7070
"""
71-
get_hybridcase_transforms(::AbstractHybridCase; scenario)
71+
get_hybridproblem_transforms(::AbstractHybridProblem; scenario)
7272
7373
Return a NamedTupe of
7474
- `transP`: Bijectors.Transform for the global PBM parameters, θP
7575
- `transM`: Bijectors.Transform for the single-site PBM parameters, θM
7676
"""
77-
function get_hybridcase_transforms end
77+
function get_hybridproblem_transforms end
7878

7979
# """
80-
# get_hybridcase_par_templates(::AbstractHybridCase; scenario)
80+
# get_hybridproblem_par_templates(::AbstractHybridProblem; scenario)
8181
# Provide a NamedTuple of number of
8282
# - n_covar: covariates xM
8383
# - n_site: all sites in the data
8484
# - n_batch: sites in one minibatch during fitting
8585
# - n_θM, n_θP: entries in parameter vectors
8686
# """
87-
# function get_hybridcase_sizes end
87+
# function get_hybridproblem_sizes end
8888

8989
"""
90-
get_hybridcase_n_covar(::AbstractHybridCase; scenario)
90+
get_hybridproblem_n_covar(::AbstractHybridProblem; scenario)
9191
9292
Provide the number of covariates. Default returns the number of rows in `xM` from
93-
`get_hybridcase_train_dataloader`.
93+
`get_hybridproblem_train_dataloader`.
9494
"""
95-
function get_hybridcase_n_covar(case::AbstractHybridCase; scenario)
96-
train_loader = get_hybridcase_train_dataloader(Random.default_rng(), case; scenario)
95+
function get_hybridproblem_n_covar(prob::AbstractHybridProblem; scenario)
96+
train_loader = get_hybridproblem_train_dataloader(Random.default_rng(), prob; scenario)
9797
(xM, xP, y_o, y_unc) = first(train_loader)
9898
n_covar = size(xM, 1)
9999
return(n_covar)
100100
end
101101

102102
"""
103-
gen_hybridcase_synthetic([rng,] ::AbstractHybridCase; scenario)
103+
gen_hybridcase_synthetic([rng,] ::AbstractHybridProblem; scenario)
104104
105105
Setup synthetic data, a NamedTuple of
106106
- xM: matrix of covariates, with one column per site
@@ -114,40 +114,40 @@ Setup synthetic data, a NamedTuple of
114114
function gen_hybridcase_synthetic end
115115

116116
"""
117-
get_hybridcase_float_type(::AbstractHybridCase; scenario)
117+
get_hybridproblem_float_type(::AbstractHybridProblem; scenario)
118118
119119
Determine the FloatType for given Case and scenario, defaults to Float32
120120
"""
121-
function get_hybridcase_float_type(case::AbstractHybridCase; scenario=())
122-
return eltype(get_hybridcase_par_templates(case; scenario).θM)
121+
function get_hybridproblem_float_type(prob::AbstractHybridProblem; scenario=())
122+
return eltype(get_hybridproblem_par_templates(prob; scenario).θM)
123123
end
124124

125125
"""
126-
get_hybridcase_train_dataloader([rng,] ::AbstractHybridCase; scenario)
126+
get_hybridproblem_train_dataloader([rng,] ::AbstractHybridProblem; scenario)
127127
128128
Return a DataLoader that provides a tuple of
129129
- `xM`: matrix of covariates, with one column per site
130130
- `xP`: Iterator of process-model drivers, with one element per site
131131
- `y_o`: matrix of observations with added noise, with one column per site
132132
- `y_unc`: matrix `sizeof(y_o)` of uncertainty information
133133
"""
134-
function get_hybridcase_train_dataloader(rng::AbstractRNG, case::AbstractHybridCase;
134+
function get_hybridproblem_train_dataloader(rng::AbstractRNG, prob::AbstractHybridProblem;
135135
scenario = ())
136-
(; xM, xP, y_o, y_unc) = gen_hybridcase_synthetic(rng, case; scenario)
136+
(; xM, xP, y_o, y_unc) = gen_hybridcase_synthetic(rng, prob; scenario)
137137
n_batch = 10
138138
xM_gpu = :use_Flux scenario ? CuArray(xM) : xM
139139
train_loader = MLUtils.DataLoader((xM_gpu, xP, y_o, y_unc), batchsize = n_batch)
140140
return(train_loader)
141141
end
142142

143-
function get_hybridcase_train_dataloader(case::AbstractHybridCase; scenario = ())
143+
function get_hybridproblem_train_dataloader(prob::AbstractHybridProblem; scenario = ())
144144
rng::AbstractRNG = Random.default_rng()
145-
get_hybridcase_train_dataloader(rng, case; scenario)
145+
get_hybridproblem_train_dataloader(rng, prob; scenario)
146146
end
147147

148148

149149
"""
150-
get_hybridcase_cor_starts(case::AbstractHybridCase; scenario)
150+
get_hybridproblem_cor_starts(prob::AbstractHybridProblem; scenario)
151151
152152
Specify blocks in correlation matrices among parameters.
153153
Returns a NamedTuple.
@@ -163,7 +163,7 @@ then the first subrange starts at position 1 and the second subrange starts at p
163163
If there is only single block of all ML-predicted parameters being correlated
164164
with each other then this block starts at position 1: `(P=(1,3), M=(1,))`.
165165
"""
166-
function get_hybridcase_cor_starts(case::AbstractHybridCase; scenario = ())
166+
function get_hybridproblem_cor_starts(prob::AbstractHybridProblem; scenario = ())
167167
(P=(1,), M=(1,))
168168
end
169169

src/ComponentArrayInterpreter.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ function ComponentArrayInterpreter(
137137
ComponentArrayInterpreter(axes_ext)
138138
end
139139

140-
# ambuiguity with two empty Tuples (edge case that does not make sense)
140+
# ambuiguity with two empty Tuples (edge prob that does not make sense)
141141
# Empty ComponentVector with no other array dimensions -> empty componentVector
142142
function ComponentArrayInterpreter(n_dims1::Tuple{}, n_dims2::Tuple{})
143143
ComponentArrayInterpreter(CA.ComponentVector())

src/DoubleMM/f_doubleMM.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
struct DoubleMMCase <: AbstractHybridCase end
1+
struct DoubleMMCase <: AbstractHybridProblem end
22

33

44
θP = CA.ComponentVector{Float32}(r0 = 0.3, K2 = 2.0)
@@ -18,19 +18,19 @@ function f_doubleMM(θ::AbstractVector, x)
1818
return (y)
1919
end
2020

21-
function HVI.get_hybridcase_par_templates(::DoubleMMCase; scenario::NTuple = ())
21+
function HVI.get_hybridproblem_par_templates(::DoubleMMCase; scenario::NTuple = ())
2222
(; θP, θM)
2323
end
2424

25-
function HVI.get_hybridcase_transforms(::DoubleMMCase; scenario::NTuple = ())
25+
function HVI.get_hybridproblem_transforms(::DoubleMMCase; scenario::NTuple = ())
2626
(; transP, transM)
2727
end
2828

29-
function HVI.get_hybridcase_neg_logden_obs(::DoubleMMCase; scenario::NTuple = ())
29+
function HVI.get_hybridproblem_neg_logden_obs(::DoubleMMCase; scenario::NTuple = ())
3030
neg_logden_indep_normal
3131
end
3232

33-
# function HVI.get_hybridcase_sizes(::DoubleMMCase; scenario = ())
33+
# function HVI.get_hybridproblem_sizes(::DoubleMMCase; scenario = ())
3434
# n_covar_pc = 2
3535
# n_covar = n_covar_pc + 3 # linear dependent
3636
# #n_site = 10^n_covar_pc
@@ -41,7 +41,7 @@ end
4141
# (; n_covar, n_batch, n_θM, n_θP)
4242
# end
4343

44-
function HVI.get_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ())
44+
function HVI.get_hybridproblem_PBmodel(::DoubleMMCase; scenario::NTuple = ())
4545
#fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
4646
function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, x)
4747
pred_sites = applyf(f_doubleMM, θMs, θP, x)
@@ -50,26 +50,26 @@ function HVI.get_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ())
5050
end
5151
end
5252

53-
# function HVI.get_hybridcase_float_type(::DoubleMMCase; scenario)
53+
# function HVI.get_hybridproblem_float_type(::DoubleMMCase; scenario)
5454
# return Float32
5555
# end
5656

5757
const xP_S1 = Float32[1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.1]
5858
const xP_S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0]
5959

60-
function HVI.gen_hybridcase_synthetic(rng::AbstractRNG, case::DoubleMMCase;
60+
function HVI.gen_hybridcase_synthetic(rng::AbstractRNG, prob::DoubleMMCase;
6161
scenario = ())
6262
n_covar_pc = 2
6363
n_site = 200
6464
n_covar = 5
6565
n_θM = length(θM)
66-
FloatType = get_hybridcase_float_type(case; scenario)
66+
FloatType = get_hybridproblem_float_type(prob; scenario)
6767
xM, θMs_true0 = gen_cov_pred(rng, FloatType, n_covar_pc, n_covar, n_site, n_θM;
6868
rhodec = 8, is_using_dropout = false)
6969
int_θMs_sites = ComponentArrayInterpreter(θM, (n_site,))
7070
# normalize to be distributed around the prescribed true values
7171
θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, FloatType(0.1)))
72-
f = get_hybridcase_PBmodel(case; scenario)
72+
f = get_hybridproblem_PBmodel(prob; scenario)
7373
xP = fill((;S1=xP_S1, S2=xP_S2), n_site)
7474
y_global_true, y_true = f(θP, θMs_true, xP)
7575
σ_o = FloatType(0.01)
@@ -91,10 +91,10 @@ function HVI.gen_hybridcase_synthetic(rng::AbstractRNG, case::DoubleMMCase;
9191
)
9292
end
9393

94-
function HVI.get_hybridcase_MLapplicator(
95-
rng::AbstractRNG, case::HVI.DoubleMM.DoubleMMCase; scenario = ())
94+
function HVI.get_hybridproblem_MLapplicator(
95+
rng::AbstractRNG, prob::HVI.DoubleMM.DoubleMMCase; scenario = ())
9696
ml_engine = select_ml_engine(; scenario)
97-
construct_3layer_MLApplicator(rng, case, ml_engine; scenario)
97+
construct_3layer_MLApplicator(rng, prob, ml_engine; scenario)
9898
end
9999

100100

0 commit comments

Comments
 (0)