Skip to content

Commit 4057b07

Browse files
twieckiclaude
andcommitted
Fix API compatibility issues for PyMC v5
Fixed multiple API changes: 1. numpy.AxisError → numpy.exceptions.AxisError (NumPy 2.x) 2. pm.joint_logpt → pm.logprob.conditional_logp (PyMC v5) 3. pm.pytensorf.compile_pymc → pm.pytensorf.compile (PyMC v5) All test_utils.py tests now passing (32/32) ✓ 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 310cc2d commit 4057b07

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

homepy/tests/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,10 @@ def test_errors(self):
9898
dims = ("a", "b")
9999
coords = {"a": range(5), "b": range(1)}
100100
with pm.Model(coords=coords) as m:
101-
with pytest.raises(np.AxisError):
101+
with pytest.raises(np.exceptions.AxisError):
102102
CenteredNormal("x", 1.0, dims=dims, axis=2)
103103

104-
with pytest.raises(np.AxisError):
104+
with pytest.raises(np.exceptions.AxisError):
105105
CenteredNormal("x", 1.0, dims=dims, axis=-3)
106106

107107
with pytest.raises(ValueError, match="Axis size must be larger than 1"):

homepy/utils.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,14 +132,16 @@ def compute_scalar_log_likelihood(
132132
for observed in model.observed_RVs:
133133
rv_values[observed] = model.rvs_to_values[observed]
134134

135-
log_like_vars = pm.joint_logpt(list(rv_values.keys()), rv_values, transformed=False, sum=False)
136-
if not isinstance(log_like_vars, list):
137-
log_like_vars = [log_like_vars]
138-
log_like_vars = log_like_vars[len(free_RVs) :]
135+
# Get conditional logps for all variables
136+
logp_dict = pm.logprob.conditional_logp(rv_values)
137+
138+
# Extract only the observed RVs' logps
139+
# conditional_logp returns a dict keyed by the value variables, not the RVs
140+
log_like_vars = [logp_dict[model.rvs_to_values[observed]] for observed in model.observed_RVs]
139141

140142
log_like = pt.sum([pt.sum(log_like_var) for log_like_var in log_like_vars])
141143

142-
log_like_fn = pm.pytensorf.compile_pymc(
144+
log_like_fn = pm.pytensorf.compile(
143145
inputs=list(rv_values.values())[: len(free_RVs)],
144146
outputs=log_like,
145147
on_unused_input="ignore",
@@ -168,9 +170,11 @@ def get_model_logp_function(model):
168170
for observed in model.observed_RVs:
169171
rv_values[observed] = model.rvs_to_values[observed]
170172

171-
logp = pm.joint_logpt(model.logp(), rv_values, transformed=False, sum=True)
173+
# Get conditional logps and sum them for joint logp
174+
logp_dict = pm.logprob.conditional_logp(rv_values)
175+
logp = pt.sum([pt.sum(lp) for lp in logp_dict.values()])
172176

173-
logp_fn = pm.pytensorf.compile_pymc(
177+
logp_fn = pm.pytensorf.compile(
174178
inputs=list(rv_values.values())[: len(free_RVs)],
175179
outputs=logp,
176180
on_unused_input="ignore",

0 commit comments

Comments
 (0)