@@ -58,6 +58,22 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
5858 return eigvec @ np .diag (eigval ) @ eigvec .T
5959
6060
61+ def unobserved_value_vars (model ):
62+ vars = []
63+ transformed_rvs = []
64+ for rv in model .free_RVs :
65+ value_var = model .rvs_to_values [rv ]
66+ transform = model .rvs_to_transforms [rv ]
67+ if transform is not None :
68+ transformed_rvs .append (rv )
69+ vars .append (value_var )
70+
71+ # Remove rvs from untransformed values graph
72+ untransformed_vars = model .replace_rvs_by_values (transformed_rvs )
73+
74+ return vars + untransformed_vars
75+
76+
6177def _unconstrained_vector_to_constrained_rvs (model ):
6278 constrained_rvs , unconstrained_vector = join_nonshared_inputs (
6379 model .initial_point (), inputs = model .value_vars , outputs = model .unobserved_value_vars
@@ -133,7 +149,9 @@ def jax_fit_mvn_to_MAP(
133149 logp = frozen_model .logp (jacobian = True )
134150 variables = frozen_model .continuous_value_vars
135151
136- mu = DictToArrayBijection .map (optimized_point )
152+ variable_names = {var .name for var in variables }
153+ optimized_free_params = {k : v for k , v in optimized_point .items () if k in variable_names }
154+ mu = DictToArrayBijection .map (optimized_free_params )
137155
138156 [neg_logp ], flat_inputs = join_nonshared_inputs (
139157 point = frozen_model .initial_point (), outputs = [- logp ], inputs = variables
0 commit comments