diff --git a/pytensor/link/jax/dispatch/basic.py b/pytensor/link/jax/dispatch/basic.py
index 5e0037b439..66eb647cca 100644
--- a/pytensor/link/jax/dispatch/basic.py
+++ b/pytensor/link/jax/dispatch/basic.py
@@ -10,10 +10,11 @@
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
 from pytensor.configdefaults import config
+from pytensor.graph import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
 from pytensor.link.utils import fgraph_to_python
-from pytensor.raise_op import Assert, CheckAndRaise
+from pytensor.raise_op import CheckAndRaise


 if config.floatX == "float64":
@@ -73,11 +74,14 @@ def ifelse(cond, *args, n_outs=n_outs):
     return ifelse


-@jax_funcify.register(Assert)
 @jax_funcify.register(CheckAndRaise)
-def jax_funcify_CheckAndRaise(op, **kwargs):
+def jax_funcify_CheckAndRaise(op, node, **kwargs):
+    conds = node.inputs[1:]
+    if any(isinstance(cond, Constant) and not bool(cond.data) for cond in conds):
+        raise op.exc_type(op.msg)
+
     warnings.warn(
-        f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as JAX tracing would remove it.""",
+        f"""Skipping {op} Op (assertion: {op.msg}) as JAX tracing would remove it.""",
         stacklevel=2,
     )
diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py
index 7eafbdeb3f..62fdd14bae 100644
--- a/pytensor/link/pytorch/dispatch/basic.py
+++ b/pytensor/link/pytorch/dispatch/basic.py
@@ -22,6 +22,7 @@
     Eye,
     Join,
     MakeVector,
+    ScalarFromTensor,
     Split,
     TensorFromScalar,
 )
@@ -79,6 +80,14 @@ def type_cast(x):
     return type_cast


+@pytorch_funcify.register(ScalarFromTensor)
+def pytorch_funcify_ScalarFromTensor(op, node, **kwargs):
+    def scalar_from_tensor(x):
+        return x[()]
+
+    return scalar_from_tensor
+
+
 @pytorch_funcify.register(CheckAndRaise)
 def pytorch_funcify_CheckAndRaise(op, **kwargs):
     error = op.exc_type
@@ -86,7 +95,7 @@ def pytorch_funcify_CheckAndRaise(op, **kwargs):

     def assert_fn(x, *conditions):
         for cond in conditions:
-            if not cond.item():
+            if not cond:
                 raise error(msg)
         return x

diff --git a/pytensor/raise_op.py b/pytensor/raise_op.py
index cf951a2527..e23078b8ae 100644
--- a/pytensor/raise_op.py
+++ b/pytensor/raise_op.py
@@ -2,15 +2,13 @@

 from textwrap import indent

-import numpy as np
-
 from pytensor.gradient import DisconnectedType
-from pytensor.graph.basic import Apply, Variable
+from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.replace import _vectorize_node
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
 from pytensor.link.c.type import Generic
-from pytensor.scalar.basic import ScalarType
+from pytensor.scalar.basic import ScalarType, as_scalar
 from pytensor.tensor.type import DenseTensorType


@@ -56,18 +54,6 @@ def __str__(self):
         msg = self.msg
         return f"{name}{{raises={exc_name}, msg='{msg}'}}"

-    def __eq__(self, other):
-        if type(self) is not type(other):
-            return False
-
-        if self.msg == other.msg and self.exc_type == other.exc_type:
-            return True
-
-        return False
-
-    def __hash__(self):
-        return hash((self.msg, self.exc_type))
-
     def make_node(self, value: Variable, *conds: Variable):
         """

@@ -84,12 +70,10 @@ def make_node(self, value: Variable, *conds: Variable):
         if not isinstance(value, Variable):
             value = pt.as_tensor_variable(value)

-        conds = [
-            pt.as_tensor_variable(c) if not isinstance(c, Variable) else c
-            for c in conds
-        ]
-
-        assert all(c.type.ndim == 0 for c in conds)
+        conds = [as_scalar(c) for c in conds]
+        for i, cond in enumerate(conds):
+            if cond.dtype != "bool":
+                conds[i] = cond.astype("bool")

         return Apply(
             self,
@@ -101,7 +85,7 @@ def perform(self, node, inputs, outputs):
         (out,) = outputs
         val, *conds = inputs
         out[0] = val
-        if not np.all(conds):
+        if not all(conds):
             raise self.exc_type(self.msg)

     def grad(self, input, output_gradients):
@@ -117,38 +101,20 @@ def c_code(self, node, name, inames, onames, props):
             )
         value_name, *cond_names = inames
         out_name = onames[0]
-        check = []
         fail_code = props["fail"]
         param_struct_name = props["params"]
         msg = self.msg.replace('"', '\\"').replace("\n", "\\n")

-        for idx, cond_name in enumerate(cond_names):
-            if isinstance(node.inputs[0].type, DenseTensorType):
-                check.append(
-                    f"""
-        if(PyObject_IsTrue((PyObject *){cond_name}) == 0) {{
-            PyObject * exc_type = {param_struct_name}->exc_type;
-            Py_INCREF(exc_type);
-            PyErr_SetString(exc_type, "{msg}");
-            Py_XDECREF(exc_type);
-            {indent(fail_code, " " * 4)}
-        }}
-        """
-                )
-            else:
-                check.append(
-                    f"""
-        if({cond_name} == 0) {{
-            PyObject * exc_type = {param_struct_name}->exc_type;
-            Py_INCREF(exc_type);
-            PyErr_SetString(exc_type, "{msg}");
-            Py_XDECREF(exc_type);
-            {indent(fail_code, " " * 4)}
-        }}
-        """
-                )
-
-        check = "\n".join(check)
+        all_conds = " && ".join(cond_names)
+        check = f"""
+        if(!({all_conds})) {{
+            PyObject * exc_type = {param_struct_name}->exc_type;
+            Py_INCREF(exc_type);
+            PyErr_SetString(exc_type, "{msg}");
+            Py_XDECREF(exc_type);
+            {indent(fail_code, " " * 4)}
+        }}
+        """

         if isinstance(node.inputs[0].type, DenseTensorType):
             res = f"""
@@ -162,14 +128,19 @@ def c_code(self, node, name, inames, onames, props):
         {check}
         {out_name} = {value_name};
         """
-        return res
+
+        return "\n".join((check, res))

     def c_code_cache_version(self):
-        return (1, 1)
+        return (2,)

     def infer_shape(self, fgraph, node, input_shapes):
         return [input_shapes[0]]

+    def do_constant_folding(self, fgraph, node):
+        # Only constant-fold if the Assert does not fail
+        return all((isinstance(c, Constant) and bool(c.data)) for c in node.inputs[1:])
+

 class Assert(CheckAndRaise):
     """Implements assertion in a computational graph.
diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 9e1c6c1a14..de712b2019 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -732,20 +732,15 @@ def is_an_upcast(type1, type2):

 @register_useless
 @register_specialize
-@node_rewriter(None)
+@node_rewriter([CheckAndRaise])
 def local_remove_useless_assert(fgraph, node):
-    if not isinstance(node.op, CheckAndRaise):
-        return False
-
     new_conds = []
     n_conds = len(node.inputs[1:])
     for c in node.inputs[1:]:
         try:
             const = get_scalar_constant_value(c)

-            if 0 != const.ndim or const == 0:
-                # Should we raise an error here? How to be sure it
-                # is not caught?
+            if not const:
                 new_conds.append(c)
         except NotScalarConstantError:
             new_conds.append(c)
@@ -1106,8 +1101,15 @@ def unconditional_constant_folding(fgraph, node):
         storage_map[o] = [None]
         compute_map[o] = [False]

-    thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
-    required = thunk()
+    try:
+        thunk = node.op.make_thunk(
+            node, storage_map, compute_map, no_recycling=[], impl="py"
+        )
+        required = thunk()
+    except NotImplementedError:
+        # Not all Ops have a Python implementation
+        thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
+        required = thunk()

     # A node whose inputs are all provided should always return successfully
     assert not required
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index a959efd6d3..4a78a1e9fe 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -487,8 +487,8 @@ def test_local_remove_useless_1(self):

     def test_local_remove_useless_2(self):
         """Remove `CheckAndRaise` conditions that are always true."""
-        x = scalar()
-        y = scalar()
+        x = scalar("x")
+        y = ps.bool("y")
         fg = FunctionGraph(outputs=[assert_op(x, y, 1)], clone=False)
         fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"])
         topo = fg_res.toposort()
@@ -497,8 +497,8 @@ def test_local_remove_useless_2(self):

     def test_local_remove_useless_3(self):
         """Don't remove `CheckAndRaise` conditions that are always false."""
-        x = scalar()
-        y = scalar()
+        x = scalar("x")
+        y = ps.bool("y")
         fg = FunctionGraph(outputs=[assert_op(x, y, 0)], clone=False)
         fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"])
         topo = fg_res.toposort()
@@ -1559,7 +1559,7 @@ def test_local_merge_alloc():
     output = pt.alloc(pt.alloc(m, y, 1, 1), x, y2, z, w)
     f = function([m, x, y, y2, z, w], output, mode=rewrite_mode)
     topo = f.maker.fgraph.toposort()
-    assert len(topo) == 3
+    assert len(topo) == 4
     assert isinstance(topo[-2].op, Assert)
     assert isinstance(topo[-1].op, Alloc)
     o = f(0.0, 1, 2, 2, 3, 4)
@@ -1616,7 +1616,7 @@ def test_local_useless_alloc():
     useless_alloc.rewrite(g)

     topo = g.toposort()
-    assert len(topo) == 3
+    assert len(topo) == 4
     assert isinstance(topo[-2].op, Assert)
     assert isinstance(topo[-1].op, Alloc)

diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py
index e259b7d1a6..c23d0ac23a 100644
--- a/tests/tensor/rewriting/test_elemwise.py
+++ b/tests/tensor/rewriting/test_elemwise.py
@@ -932,7 +932,7 @@ def large_fuseable_graph(self, n):
             ),
             (fx,),
             (fxv,),
-            4,
+            5,
             (np.zeros_like(fxv),),
             ("float32",),
         ),
diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py
index e8900ce5d7..dee65c5d76 100644
--- a/tests/tensor/test_extra_ops.py
+++ b/tests/tensor/test_extra_ops.py
@@ -8,6 +8,7 @@
 from pytensor import tensor as pt
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
+from pytensor.graph import rewrite_graph
 from pytensor.graph.basic import Constant, applys_between, equal_computations
 from pytensor.npy_2_compat import old_np_unique
 from pytensor.raise_op import Assert
@@ -1252,11 +1253,17 @@ def test_broadcast_shape_symbolic_one_symbolic():
     ]

     res_shape = broadcast_shape(*index_shapes, arrays_are_shapes=True)
-
-    from pytensor.graph.rewriting.utils import rewrite_graph
-
     res_shape = rewrite_graph(res_shape)
-
+    assert res_shape[0].data == 1
+    assert res_shape[1].data == 1
+    with pytest.raises(AssertionError, match="Could not broadcast dimensions"):
+        # broadcast_shape doesn't treat int_div as a constant 1
+        res_shape[2].eval()
+
+    res_shape = broadcast_shape(
+        *index_shapes, arrays_are_shapes=True, allow_runtime_broadcast=True
+    )
+    res_shape = rewrite_graph(res_shape)
     assert res_shape[0].data == 1
     assert res_shape[1].data == 1
     assert res_shape[2].data == 3
@@ -1294,7 +1301,9 @@ def test_broadcast_arrays():
     ["linspace", "logspace", "geomspace"],
     ids=["linspace", "logspace", "geomspace"],
 )
-@pytest.mark.parametrize("dtype", [None, "int", "float"], ids=[None, "int", "float"])
+@pytest.mark.parametrize(
+    "dtype", [None, "int64", "floatX"], ids=[None, "int64", "floatX"]
+)
 @pytest.mark.parametrize(
     "start, stop, num_samples, endpoint, axis",
     [
@@ -1310,7 +1319,7 @@ def test_broadcast_arrays():
 def test_space_ops(op, dtype, start, stop, num_samples, endpoint, axis):
     pt_func = getattr(pt, op)
     np_func = getattr(np, op)
-    dtype = dtype + config.floatX[-2:] if dtype is not None else dtype
+    dtype = dtype if dtype != "floatX" else config.floatX
     z = pt_func(start, stop, num_samples, endpoint=endpoint, axis=axis, dtype=dtype)

     numpy_res = np_func(
diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py
index 38207d0f5d..9b4b8ebbb9 100644
--- a/tests/tensor/test_math.py
+++ b/tests/tensor/test_math.py
@@ -1412,30 +1412,41 @@ def _grad_list(self):
             "uint8",
             "uint16",
             "uint32",
-            pytest.param("uint64", marks=pytest.mark.xfail(reason="Fails due to #770")),
+            pytest.param(
+                "uint64",
+                marks=pytest.mark.xfail(
+                    condition=config.mode != "FAST_COMPILE", reason="Fails due to #770"
+                ),
+            ),
         ),
     )
     def test_uint(self, dtype):
         itype = np.iinfo(dtype)
-        data = np.array([itype.min + 3, itype.min, itype.max - 5, itype.max], dtype)
-        n = as_tensor_variable(data)
+        data = np.array(
+            [itype.min + 3, itype.min, itype.max - 5, itype.max], dtype=dtype
+        )
+        n = vector("n", shape=(None,), dtype=dtype)

-        assert min(n).dtype == dtype
-        i_min = eval_outputs(min(n))
+        min_out = min(n)
+        assert min_out.dtype == dtype
+        i_min = function([n], min_out)(data)
         assert i_min == itype.min

-        assert max(n).dtype == dtype
-        i_max = eval_outputs(max(n))
+        max_out = max(n)
+        assert max_out.dtype == dtype
+        i_max = function([n], max_out)(data)
         assert i_max == itype.max

-    @pytest.mark.xfail(reason="Fails due to #770")
+    @pytest.mark.xfail(
+        condition=config.mode != "FAST_COMPILE", reason="Fails due to #770"
+    )
     def test_uint64_special_value(self):
         """Example from issue #770"""
         dtype = "uint64"
         data = np.array([0, 9223372036854775], dtype=dtype)
-        n = as_tensor_variable(data)
+        n = vector("n", shape=(None,), dtype=dtype)

-        i_max = eval_outputs(max(n))
+        i_max = function([n], max(n))(data)
         assert i_max == data.max()

     def test_bool(self):
diff --git a/tests/test_raise_op.py b/tests/test_raise_op.py
index 7d10f760d9..9ba6040418 100644
--- a/tests/test_raise_op.py
+++ b/tests/test_raise_op.py
@@ -82,19 +82,26 @@ def test_CheckAndRaise_basic_c(linker):

     with pytest.raises(CustomException, match=exc_msg):
         y_fn(0)
+    assert y_fn(1) == 1.0

     x = pt.vector()
+    x_val = np.array([1.0], dtype=pytensor.config.floatX)
+
     y = check_and_raise(x, conds)
-    y_fn = pytensor.function([conds, x], y.shape, mode=Mode(linker, OPT_FAST_RUN))
+    y_fn = pytensor.function([conds, x], y, mode=Mode(linker, OPT_FAST_RUN))
+    with pytest.raises(CustomException, match=exc_msg):
+        y_fn(0, x_val)
+    assert np.array_equal(y_fn(1, x_val), x_val)

-    x_val = np.array([1.0], dtype=pytensor.config.floatX)
+    y_fn = pytensor.function([conds, x], y.shape, mode=Mode(linker, OPT_FAST_RUN))
+    # The shape doesn't depend on y, so the Assert is dropped from the graph
     assert np.array_equal(y_fn(0, x_val), x_val)

     y = check_and_raise(x, pt.as_tensor(0))
-    y_grad = pytensor.grad(y.sum(), [x])
+    y_grad = pytensor.grad(y.sum(), x)
     y_fn = pytensor.function([x], y_grad, mode=Mode(linker, OPT_FAST_RUN))
-
-    assert np.array_equal(y_fn(x_val), [x_val])
+    # The gradient doesn't depend on y, just its shape, so the Assert is dropped from the graph
+    assert np.array_equal(y_fn(x_val), x_val)


 @pytest.mark.parametrize(