@@ -7,7 +7,7 @@ struct HybridProblem <: AbstractHybridProblem
77 py
88 transP
99 transM
10- cor_starts # = (P=(1,),M=(1,))
10+ cor_ends # = (P=(1,),M=(1,))
1111 get_train_loader
1212 # inner constructor to constrain the types
1313 function HybridProblem (
@@ -20,8 +20,8 @@ struct HybridProblem <: AbstractHybridProblem
2020 # train_loader::DataLoader,
2121 # return a function that constructs the trainloader based on n_batch
2222 get_train_loader:: Function ,
23- cor_starts :: NamedTuple = (P = ( 1 ,) , M = ( 1 ,) ))
24- new (θP, θM, f, g, ϕg, py, transM, transP, cor_starts , get_train_loader)
23+ cor_ends :: NamedTuple = (P = [ length (θP)] , M = [ length (θM)] ))
24+ new (θP, θM, f, g, ϕg, py, transM, transP, cor_ends , get_train_loader)
2525 end
2626end
2727
@@ -45,8 +45,8 @@ function HybridProblem(prob::AbstractHybridProblem; scenario = ())
4545 get_hybridproblem_train_dataloader (rng:: AbstractRNG , prob; scenario, kwargs... )
4646 end
4747 end
48- cor_starts = get_hybridproblem_cor_starts (prob; scenario)
49- HybridProblem (θP, θM, g, ϕg, f, py, transP, transM, get_train_loader, cor_starts )
48+ cor_ends = get_hybridproblem_cor_ends (prob; scenario)
49+ HybridProblem (θP, θM, g, ϕg, f, py, transP, transM, get_train_loader, cor_ends )
5050end
5151
5252function update (prob:: HybridProblem ;
@@ -58,7 +58,7 @@ function update(prob::HybridProblem;
5858 transM:: Union{Function, Bijectors.Transform} = prob. transM,
5959 transP:: Union{Function, Bijectors.Transform} = prob. transP,
6060 get_train_loader:: Function = prob. get_train_loader,
61- cor_starts :: NamedTuple = prob. cor_starts )
61+ cor_ends :: NamedTuple = prob. cor_ends )
6262 # prob.θP = θP
6363 # prob.θM = θM
6464 # prob.f = f
@@ -67,15 +67,19 @@ function update(prob::HybridProblem;
6767 # prob.py = py
6868 # prob.transM = transM
6969 # prob.transP = transP
70- # prob.cor_starts = cor_starts
70+ # prob.cor_ends = cor_ends
7171 # prob.get_train_loader = get_train_loader
72- HybridProblem (θP, θM, g, ϕg, f, py, transP, transM, get_train_loader, cor_starts )
72+ HybridProblem (θP, θM, g, ϕg, f, py, transP, transM, get_train_loader, cor_ends )
7373end
7474
7575function get_hybridproblem_par_templates (prob:: HybridProblem ; scenario:: NTuple = ())
7676 (; θP = prob. θP, θM = prob. θM)
7777end
7878
79+ function get_hybridproblem_ϕunc (prob:: HybridProblem ; scenario:: NTuple = ())
80+ prob. ϕunc
81+ end
82+
7983function get_hybridproblem_neg_logden_obs (prob:: HybridProblem ; scenario:: NTuple = ())
8084 prob. py
8185end
@@ -102,8 +106,8 @@ function get_hybridproblem_train_dataloader(rng::AbstractRNG, prob::HybridProble
102106 return prob. get_train_loader (rng; kwargs... )
103107end
104108
105- function get_hybridproblem_cor_starts (prob:: HybridProblem ; scenario = ())
106- prob. cor_starts
109+ function get_hybridproblem_cor_ends (prob:: HybridProblem ; scenario = ())
110+ prob. cor_ends
107111end
108112
109113# function get_hybridproblem_float_type(prob::HybridProblem; scenario::NTuple = ())
0 commit comments