@@ -8,10 +8,10 @@ For a specific case, provide functions that specify details
88- `get_hybridcase_neg_logden_obs`
99- `get_hybridcase_par_templates`
1010- `get_hybridcase_transforms`
11- - `get_hybridcase_sizes`
1211- `get_hybridcase_train_dataloader` (default depends on `gen_hybridcase_synthetic`)
1312optionally
1413- `gen_hybridcase_synthetic`
14+ - `get_hybridcase_n_covar` (defaults to number of rows in xM in train_dataloader )
1515- `get_hybridcase_float_type` (defaults to `eltype(θM)`)
1616- `get_hybridcase_cor_starts` (defaults to include all correlations: `(P=(1,), M=(1,))`)
1717"""
@@ -79,19 +79,31 @@ Return a NamedTupe of
7979"""
8080function get_hybridcase_transforms end
8181
82+ # """
83+ # get_hybridcase_par_templates(::AbstractHybridCase; scenario)
84+ # Provide a NamedTuple of number of
85+ # - n_covar: covariates xM
86+ # - n_site: all sites in the data
87+ # - n_batch: sites in one minibatch during fitting
88+ # - n_θM, n_θP: entries in parameter vectors
89+ # """
90+ # function get_hybridcase_sizes end
91+
8292"""
83- get_hybridcase_par_templates (::AbstractHybridCase; scenario)
93+ get_hybridcase_n_covar (::AbstractHybridCase; scenario)
8494
85- Provide a NamedTuple of number of
86- - n_covar: covariates xM
87- - n_site: all sites in the data
88- - n_batch: sites in one minibatch during fitting
89- - n_θM, n_θP: entries in parameter vectors
95+ Provide the number of covariates. Default returns the number of rows in `xM` from
96+ `get_hybridcase_train_dataloader`.
9097"""
91- function get_hybridcase_sizes end
98+ function get_hybridcase_n_covar (case:: AbstractHybridCase ; scenario)
99+ train_loader = get_hybridcase_train_dataloader (Random. default_rng (), case; scenario)
100+ (xM, xP, y_o, y_unc) = first (train_loader)
101+ n_covar = size (xM, 1 )
102+ return (n_covar)
103+ end
92104
93105"""
94- gen_hybridcase_synthetic(::AbstractHybridCase, rng ; scenario)
106+ gen_hybridcase_synthetic([rng,] ::AbstractHybridCase; scenario)
95107
96108Setup synthetic data, a NamedTuple of
97109- xM: matrix of covariates, with one column per site
@@ -114,23 +126,29 @@ function get_hybridcase_float_type(case::AbstractHybridCase; scenario=())
114126end
115127
116128"""
117- get_hybridcase_train_dataloader(::AbstractHybridCase, rng ; scenario)
129+ get_hybridcase_train_dataloader([rng,] ::AbstractHybridCase; scenario)
118130
119131Return a DataLoader that provides a tuple of
120132- `xM`: matrix of covariates, with one column per site
121133- `xP`: Iterator of process-model drivers, with one element per site
122134- `y_o`: matrix of observations with added noise, with one column per site
123135- `y_unc`: matrix `sizeof(y_o)` of uncertainty information
124136"""
125- function get_hybridcase_train_dataloader (case :: AbstractHybridCase , rng :: AbstractRNG ;
137+ function get_hybridcase_train_dataloader (rng :: AbstractRNG , case :: AbstractHybridCase ;
126138 scenario = ())
127- (; xM, xP, y_o, y_unc) = gen_hybridcase_synthetic (case, rng ; scenario)
128- (; n_batch) = get_hybridcase_sizes (case; scenario)
139+ (; xM, xP, y_o, y_unc) = gen_hybridcase_synthetic (rng, case ; scenario)
140+ n_batch = 10
129141 xM_gpu = :use_flux ∈ scenario ? CuArray (xM) : xM
130142 train_loader = MLUtils. DataLoader ((xM_gpu, xP, y_o, y_unc), batchsize = n_batch)
131143 return (train_loader)
132144end
133145
146+ function get_hybridcase_train_dataloader (case:: AbstractHybridCase ; scenario = ())
147+ rng:: AbstractRNG = Random. default_rng ()
148+ get_hybridcase_train_dataloader (rng, case; scenario)
149+ end
150+
151+
134152"""
135153 get_hybridcase_cor_starts(case::AbstractHybridCase; scenario)
136154
0 commit comments