diff --git a/pymc/model/core.py b/pymc/model/core.py
index 66e633e15..65c116ba9 100644
--- a/pymc/model/core.py
+++ b/pymc/model/core.py
@@ -564,15 +564,16 @@ def logp_dlogp_function(
             for var in self.value_vars
             if var in input_vars and var not in grad_vars
         }
-        return ValueGradFunction(
-            costs,
-            grad_vars,
-            extra_vars_and_values,
-            model=self,
-            initial_point=initial_point,
-            ravel_inputs=ravel_inputs,
-            **kwargs,
-        )
+        with self:
+            return ValueGradFunction(
+                costs,
+                grad_vars,
+                extra_vars_and_values,
+                model=self,
+                initial_point=initial_point,
+                ravel_inputs=ravel_inputs,
+                **kwargs,
+            )
 
     def compile_logp(
         self,
diff --git a/pymc/model/fgraph.py b/pymc/model/fgraph.py
index 5dc47fe0e..ab2b554bf 100644
--- a/pymc/model/fgraph.py
+++ b/pymc/model/fgraph.py
@@ -223,6 +223,7 @@ def fgraph_from_model(
         copy_inputs=True,
     )
     # Copy model meta-info to fgraph
+    fgraph.check_bounds = model.check_bounds
     fgraph._coords = model._coords.copy()
     fgraph._dim_lengths = {k: memo.get(v, v) for k, v in model._dim_lengths.items()}
 
@@ -318,6 +319,7 @@ def first_non_model_var(var):
     # TODO: Consider representing/extracting them from the fgraph!
     _dim_lengths = {k: memo.get(v, v) for k, v in _dim_lengths.items()}
 
+    model.check_bounds = getattr(fgraph, "check_bounds", False)
     model._coords = _coords
     model._dim_lengths = _dim_lengths
 
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index f1d69c928..b270ab653 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -528,24 +528,30 @@ def join_nonshared_inputs(
         raise ValueError("Empty list of input variables.")
 
     raveled_inputs = pt.concatenate([var.ravel() for var in inputs])
+    input_sizes = [point[var.name].size for var in inputs]
+    size = sum(input_sizes)
 
     if not make_inputs_shared:
-        tensor_type = raveled_inputs.type
-        joined_inputs = tensor_type("joined_inputs")
+        joined_inputs = pt.tensor("joined_inputs", shape=(size,), dtype=raveled_inputs.dtype)
     else:
         joined_values = np.concatenate([point[var.name].ravel() for var in inputs])
-        joined_inputs = pytensor.shared(joined_values, "joined_inputs")
+        joined_inputs = pytensor.shared(joined_values, "joined_inputs", shape=(size,))
 
     if pytensor.config.compute_test_value != "off":
         joined_inputs.tag.test_value = raveled_inputs.tag.test_value
 
     replace: dict[TensorVariable, TensorVariable] = {}
-    last_idx = 0
-    for var in inputs:
+    if len(inputs) == 1:
+        split_vars = [joined_inputs]
+    else:
+        split_vars = pt.split(joined_inputs, input_sizes, len(inputs))
+
+    for var, flat_var in zip(inputs, split_vars, strict=True):
         shape = point[var.name].shape
-        arr_len = np.prod(shape, dtype=int)
-        replace[var] = joined_inputs[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype)
-        last_idx += arr_len
+        flat_var.name = f"{var.name}__flat"
+        reshaped_var = flat_var.reshape(shape)
+        reshaped_var.name = var.name
+        replace[var] = reshaped_var
 
     if shared_inputs is not None:
         replace.update(shared_inputs)
diff --git a/tests/model/test_core.py b/tests/model/test_core.py
index 814cb114d..98ff639f9 100644
--- a/tests/model/test_core.py
+++ b/tests/model/test_core.py
@@ -443,6 +443,15 @@ def test_missing_data(self):
         # Assert that all the elements of res are equal
         assert res[1:] == res[:-1]
 
+    def test_check_bounds_out_of_model_context(self):
+        with pm.Model(check_bounds=False) as m:
+            x = pm.Normal("x")
+            y = pm.Normal("y", sigma=x)
+        fn = m.logp_dlogp_function(ravel_inputs=True)
+        fn.set_extra_values({})
+        # Without bounds checks, logp evaluates to `nan` (not `-inf`) for invalid values
+        assert np.isnan(fn(np.array([-1.0, -1.0]))[0])
+
 
 class TestPytensorRelatedLogpBugs:
     def test_pytensor_switch_broadcast_edge_cases_1(self):
diff --git a/tests/model/test_fgraph.py b/tests/model/test_fgraph.py
index a3f04e3ce..a06e3949f 100644
--- a/tests/model/test_fgraph.py
+++ b/tests/model/test_fgraph.py
@@ -397,3 +397,13 @@ def test_multivariate_transform():
     new_ip = new_m.initial_point()
     np.testing.assert_allclose(ip["x_simplex__"], new_ip["x_simplex__"])
     np.testing.assert_allclose(ip["y_cholesky-cov-packed__"], new_ip["y_cholesky-cov-packed__"])
+
+
+def test_check_bounds_preserved():
+    with pm.Model(check_bounds=True) as m:
+        x = pm.HalfNormal("x")
+
+    assert clone_model(m).check_bounds
+
+    m.check_bounds = False
+    assert not clone_model(m).check_bounds
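
Context note on the `join_nonshared_inputs` rework above (not part of the patch): the helper replaces a model's separate value variables with views into a single flat input vector. A minimal usage sketch follows; the model, variable names, and shapes are illustrative assumptions, not taken from this diff:

    import numpy as np
    import pytensor
    import pymc as pm
    from pymc.pytensorf import join_nonshared_inputs

    # Hypothetical two-variable model (not from the patch)
    with pm.Model() as m:
        x = pm.Normal("x")           # scalar value variable (size 1)
        y = pm.Normal("y", shape=3)  # vector value variable (size 3)

    # Rewrite the joint logp so it takes one flat vector covering x and y
    [logp], joined = join_nonshared_inputs(
        point=m.initial_point(),
        outputs=[m.logp()],
        inputs=m.value_vars,
    )
    fn = pytensor.function([joined], logp)
    fn(np.zeros(4))  # evaluates logp at x=0, y=[0, 0, 0]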