Skip to content

Commit b060d01

Browse files
fix unpickling PointFunc bug and add regression test
1 parent ae43026 commit b060d01

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

pymc/pytensorf.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -569,8 +569,15 @@ def __call__(self, state):
569569

570570
def __getattr__(self, item):
571571
"""Allow access to the original function attributes."""
572-
# This is only reached if `__getattribute__` fails.
573-
return getattr(self.f, item)
572+
# During unpickling, attribute access will occur before self.f is set.
573+
# This leads to an infinite loop, trying to call __getattr__ on self.f.
574+
# To break this loop, first check if self.f is set, and raise if not.
575+
try:
576+
f = object.__getattribute__(self, "f")
577+
except AttributeError:
578+
raise AttributeError(item)
579+
580+
return getattr(f, item)
574581

575582

576583
class CallableTensor:

tests/step_methods/test_metropolis.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,3 +425,20 @@ def test_sampling_state(step_method, model_fn):
425425
assert equal_sampling_states(final_state1, final_state2)
426426
assert equal_dataclass_values(sample1, sample2)
427427
assert equal_dataclass_values(stat1, stat2)
428+
429+
430+
def test_binary_gibbs_with_spawn():
431+
# Regression test for https://github.com/pymc-devs/pymc/issues/7857
432+
433+
with pm.Model():
434+
x = pm.Categorical("x", logit_p=[1.0, 1.0, 1.0, 1.0])
435+
idata = pm.sample(
436+
step=pm.CategoricalGibbsMetropolis([x]),
437+
mp_ctx="spawn",
438+
compute_convergence_checks=False,
439+
chains=2,
440+
tune=0,
441+
draws=50,
442+
)
443+
444+
assert idata

0 commit comments

Comments
 (0)