Skip to content

Commit a23762b

Browse files
More refactor
1 parent f705d43 commit a23762b

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

pymc_experimental/inference/jax_find_map.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,22 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
5858
return eigvec @ np.diag(eigval) @ eigvec.T
5959

6060

61+
def unobserved_value_vars(model):
62+
vars = []
63+
transformed_rvs = []
64+
for rv in model.free_RVs:
65+
value_var = model.rvs_to_values[rv]
66+
transform = model.rvs_to_transforms[rv]
67+
if transform is not None:
68+
transformed_rvs.append(rv)
69+
vars.append(value_var)
70+
71+
# Remove rvs from untransformed values graph
72+
untransformed_vars = model.replace_rvs_by_values(transformed_rvs)
73+
74+
return vars + untransformed_vars
75+
76+
6177
def _unconstrained_vector_to_constrained_rvs(model):
6278
constrained_rvs, unconstrained_vector = join_nonshared_inputs(
6379
model.initial_point(), inputs=model.value_vars, outputs=model.unobserved_value_vars
@@ -133,7 +149,9 @@ def jax_fit_mvn_to_MAP(
133149
logp = frozen_model.logp(jacobian=True)
134150
variables = frozen_model.continuous_value_vars
135151

136-
mu = DictToArrayBijection.map(optimized_point)
152+
variable_names = {var.name for var in variables}
153+
optimized_free_params = {k: v for k, v in optimized_point.items() if k in variable_names}
154+
mu = DictToArrayBijection.map(optimized_free_params)
137155

138156
[neg_logp], flat_inputs = join_nonshared_inputs(
139157
point=frozen_model.initial_point(), outputs=[-logp], inputs=variables

0 commit comments

Comments
 (0)