@@ -58,6 +58,22 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
58
58
return eigvec @ np .diag (eigval ) @ eigvec .T
59
59
60
60
61
+ def unobserved_value_vars (model ):
62
+ vars = []
63
+ transformed_rvs = []
64
+ for rv in model .free_RVs :
65
+ value_var = model .rvs_to_values [rv ]
66
+ transform = model .rvs_to_transforms [rv ]
67
+ if transform is not None :
68
+ transformed_rvs .append (rv )
69
+ vars .append (value_var )
70
+
71
+ # Remove rvs from untransformed values graph
72
+ untransformed_vars = model .replace_rvs_by_values (transformed_rvs )
73
+
74
+ return vars + untransformed_vars
75
+
76
+
61
77
def _unconstrained_vector_to_constrained_rvs (model ):
62
78
constrained_rvs , unconstrained_vector = join_nonshared_inputs (
63
79
model .initial_point (), inputs = model .value_vars , outputs = model .unobserved_value_vars
@@ -133,7 +149,9 @@ def jax_fit_mvn_to_MAP(
133
149
logp = frozen_model .logp (jacobian = True )
134
150
variables = frozen_model .continuous_value_vars
135
151
136
- mu = DictToArrayBijection .map (optimized_point )
152
+ variable_names = {var .name for var in variables }
153
+ optimized_free_params = {k : v for k , v in optimized_point .items () if k in variable_names }
154
+ mu = DictToArrayBijection .map (optimized_free_params )
137
155
138
156
[neg_logp ], flat_inputs = join_nonshared_inputs (
139
157
point = frozen_model .initial_point (), outputs = [- logp ], inputs = variables
0 commit comments