Skip to content

Commit b823479

Browse files
committed
Use model context in logp_dlogp_function to respect check_bounds
1 parent b91c4a0 commit b823479

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

pymc/model/core.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -564,15 +564,16 @@ def logp_dlogp_function(
564564
for var in self.value_vars
565565
if var in input_vars and var not in grad_vars
566566
}
567-
return ValueGradFunction(
568-
costs,
569-
grad_vars,
570-
extra_vars_and_values,
571-
model=self,
572-
initial_point=initial_point,
573-
ravel_inputs=ravel_inputs,
574-
**kwargs,
575-
)
567+
with self:
568+
return ValueGradFunction(
569+
costs,
570+
grad_vars,
571+
extra_vars_and_values,
572+
model=self,
573+
initial_point=initial_point,
574+
ravel_inputs=ravel_inputs,
575+
**kwargs,
576+
)
576577

577578
def compile_logp(
578579
self,

tests/model/test_core.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,15 @@ def test_missing_data(self):
443443
# Assert that all the elements of res are equal
444444
assert res[1:] == res[:-1]
445445

446+
def test_check_bounds_out_of_model_context(self):
447+
with pm.Model(check_bounds=False) as m:
448+
x = pm.Normal("x")
449+
y = pm.Normal("y", sigma=x)
450+
fn = m.logp_dlogp_function(ravel_inputs=True)
451+
fn.set_extra_values({})
452+
# When there are no bounds check logp turns into `nan`
453+
assert np.isnan(fn(np.array([-1.0, -1.0]))[0])
454+
446455

447456
class TestPytensorRelatedLogpBugs:
448457
def test_pytensor_switch_broadcast_edge_cases_1(self):

0 commit comments

Comments
 (0)