@@ -33,7 +33,7 @@ def fit_dadvi(
33
33
n_fixed_draws : int = 30 ,
34
34
random_seed : RandomSeed = None ,
35
35
n_draws : int = 1000 ,
36
- keep_untransformed : bool = False ,
36
+ include_transformed : bool = False ,
37
37
optimizer_method : minimize_method = "trust-ncg" ,
38
38
use_grad : bool | None = None ,
39
39
use_hessp : bool | None = None ,
@@ -63,7 +63,7 @@ def fit_dadvi(
63
63
n_draws: int
64
64
The number of draws to return from the variational approximation.
65
65
66
- keep_untransformed : bool
66
+ include_transformed : bool
67
67
Whether or not to keep the unconstrained variables (such as logs of positive-constrained parameters) in the
68
68
output.
69
69
@@ -166,9 +166,7 @@ def fit_dadvi(
166
166
draws = opt_means + draws_raw * np .exp (opt_log_sds )
167
167
draws_arviz = unstack_laplace_draws (draws , model , chains = 1 , draws = n_draws )
168
168
169
- idata = az .InferenceData (
170
- posterior = transform_draws (draws_arviz , model , keep_untransformed = keep_untransformed )
171
- )
169
+ idata = dadvi_result_to_idata (draws_arviz , model , include_transformed = include_transformed )
172
170
173
171
var_name_to_model_var = {f"{ var_name } _mu" : var_name for var_name in initial_point_dict .keys ()}
174
172
var_name_to_model_var .update (
@@ -251,10 +249,10 @@ def create_dadvi_graph(
251
249
return var_params , objective
252
250
253
251
254
- def transform_draws (
252
+ def dadvi_result_to_idata (
255
253
unstacked_draws : xarray .Dataset ,
256
254
model : Model ,
257
- keep_untransformed : bool = False ,
255
+ include_transformed : bool = False ,
258
256
):
259
257
"""
260
258
Transforms the unconstrained draws back into the constrained space.
@@ -270,7 +268,7 @@ def transform_draws(
270
268
n_draws: int
271
269
The number of draws to return from the variational approximation.
272
270
273
- keep_untransformed : bool
271
+ include_transformed : bool
274
272
Whether or not to keep the unconstrained variables in the output.
275
273
276
274
Returns
@@ -281,7 +279,7 @@ def transform_draws(
281
279
282
280
filtered_var_names = model .unobserved_value_vars
283
281
vars_to_sample = list (
284
- get_default_varnames (filtered_var_names , include_transformed = keep_untransformed )
282
+ get_default_varnames (filtered_var_names , include_transformed = include_transformed )
285
283
)
286
284
fn = pytensor .function (model .value_vars , vars_to_sample )
287
285
point_func = PointFunc (fn )
@@ -296,4 +294,17 @@ def transform_draws(
296
294
dims = dims ,
297
295
)
298
296
299
- return transformed_result
297
+ constrained_names = [
298
+ x .name for x in get_default_varnames (model .unobserved_value_vars , include_transformed = False )
299
+ ]
300
+ all_varnames = [
301
+ x .name for x in get_default_varnames (model .unobserved_value_vars , include_transformed = True )
302
+ ]
303
+ unconstrained_names = set (all_varnames ) - set (constrained_names )
304
+
305
+ idata = az .InferenceData (posterior = transformed_result [constrained_names ])
306
+
307
+ if unconstrained_names and include_transformed :
308
+ idata ["unconstrained_posterior" ] = transformed_result [unconstrained_names ]
309
+
310
+ return idata
0 commit comments