
Commit 6cd90ee

Fix CheckAndRaise C implementation with tensor conditions.
For performance, the Op now always converts the condition inputs to boolean scalars. Also, do not constant-fold the Op if doing so would raise.
1 parent: b9fc4f8

8 files changed: +73 -80 lines changed
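For context, here is a minimal usage sketch of the Op this commit touches (the variable names are illustrative, not part of the commit): condition inputs to `CheckAndRaise` are now cast to boolean scalars in `make_node`, and a failing check raises the configured exception at runtime rather than being silently folded away.

import pytensor
import pytensor.tensor as pt
from pytensor.raise_op import CheckAndRaise

x = pt.scalar("x")
check = CheckAndRaise(ValueError, "x must be positive")
# `x > 0` is a 0-d tensor condition; make_node converts it to a boolean scalar
y = check(x, x > 0)

fn = pytensor.function([x], y)
assert fn(2.0) == 2.0
# fn(-1.0) raises ValueError: x must be positive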

pytensor/link/jax/dispatch/basic.py

Lines changed: 8 additions & 4 deletions
@@ -10,10 +10,11 @@
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
 from pytensor.configdefaults import config
+from pytensor.graph import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
 from pytensor.link.utils import fgraph_to_python
-from pytensor.raise_op import Assert, CheckAndRaise
+from pytensor.raise_op import CheckAndRaise


 if config.floatX == "float64":
@@ -73,11 +74,14 @@ def ifelse(cond, *args, n_outs=n_outs):
     return ifelse


-@jax_funcify.register(Assert)
 @jax_funcify.register(CheckAndRaise)
-def jax_funcify_CheckAndRaise(op, **kwargs):
+def jax_funcify_CheckAndRaise(op, node, **kwargs):
+    conds = node.inputs[1:]
+    if any(isinstance(cond, Constant) and not bool(cond.data) for cond in conds):
+        raise op.exc_type(op.msg)
+
     warnings.warn(
-        f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as JAX tracing would remove it.""",
+        f"""Skipping {op} Op (assertion: {op.msg}) as JAX tracing would remove it.""",
         stacklevel=2,
     )

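In words: JAX tracing cannot execute the runtime check, so the dispatcher now fails fast when any condition is already a constant False and otherwise warns that the assertion is skipped. A hedged sketch of the equivalent standalone logic (the function name and the pass-through body are illustrative assumptions, not taken from the diff):

import warnings
from pytensor.graph.basic import Constant

def funcify_check_and_raise(op, node):
    # Fail at compile time if a condition is statically known to be False
    conds = node.inputs[1:]
    if any(isinstance(cond, Constant) and not bool(cond.data) for cond in conds):
        raise op.exc_type(op.msg)

    warnings.warn(f"Skipping {op} Op (assertion: {op.msg}) as JAX tracing would remove it.")

    def check_and_raise(value, *conditions):
        # Under JAX the check itself is dropped; the value passes through
        return value

    return check_and_raise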

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 10 additions & 1 deletion
@@ -22,6 +22,7 @@
     Eye,
     Join,
     MakeVector,
+    ScalarFromTensor,
     Split,
     TensorFromScalar,
 )
@@ -79,14 +80,22 @@ def type_cast(x):
     return type_cast


+@pytorch_funcify.register(ScalarFromTensor)
+def pytorch_funcify_ScalarFromTensor(op, node, **kwargs):
+    def scalar_from_tensor(x):
+        return x[()]
+
+    return scalar_from_tensor
+
+
 @pytorch_funcify.register(CheckAndRaise)
 def pytorch_funcify_CheckAndRaise(op, **kwargs):
     error = op.exc_type
     msg = op.msg

     def assert_fn(x, *conditions):
         for cond in conditions:
-            if not cond.item():
+            if not cond:
                 raise error(msg)
         return x

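The switch from `cond.item()` to `not cond` works because, with the `raise_op.py` changes below, conditions reach the backend as 0-d booleans whose truthiness can be taken directly. A small illustration, using NumPy scalars to stand in for the runtime values (an assumption made for brevity; the PyTorch backend sees 0-d torch tensors, which behave the same way under `bool()`):

import numpy as np

def assert_fn(x, *conditions):
    # Mirrors the dispatched function above: any falsy condition raises
    for cond in conditions:
        if not cond:
            raise ValueError("assertion failed")
    return x

assert assert_fn(3.0, np.bool_(True)) == 3.0
# assert_fn(3.0, np.bool_(False)) raises ValueError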

pytensor/raise_op.py

Lines changed: 24 additions & 53 deletions
@@ -2,15 +2,13 @@

 from textwrap import indent

-import numpy as np
-
 from pytensor.gradient import DisconnectedType
-from pytensor.graph.basic import Apply, Variable
+from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.replace import _vectorize_node
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
 from pytensor.link.c.type import Generic
-from pytensor.scalar.basic import ScalarType
+from pytensor.scalar.basic import ScalarType, as_scalar
 from pytensor.tensor.type import DenseTensorType


@@ -56,18 +54,6 @@ def __str__(self):
         msg = self.msg
         return f"{name}{{raises={exc_name}, msg='{msg}'}}"

-    def __eq__(self, other):
-        if type(self) is not type(other):
-            return False
-
-        if self.msg == other.msg and self.exc_type == other.exc_type:
-            return True
-
-        return False
-
-    def __hash__(self):
-        return hash((self.msg, self.exc_type))
-
     def make_node(self, value: Variable, *conds: Variable):
         """
@@ -84,12 +70,10 @@ def make_node(self, value: Variable, *conds: Variable):
         if not isinstance(value, Variable):
             value = pt.as_tensor_variable(value)

-        conds = [
-            pt.as_tensor_variable(c) if not isinstance(c, Variable) else c
-            for c in conds
-        ]
-
-        assert all(c.type.ndim == 0 for c in conds)
+        conds = [as_scalar(c) for c in conds]
+        for i, cond in enumerate(conds):
+            if cond.dtype != "bool":
+                conds[i] = cond.astype("bool")

         return Apply(
             self,
@@ -101,7 +85,7 @@ def perform(self, node, inputs, outputs):
         (out,) = outputs
         val, *conds = inputs
         out[0] = val
-        if not np.all(conds):
+        if not all(conds):
             raise self.exc_type(self.msg)

     def grad(self, input, output_gradients):
@@ -117,38 +101,20 @@ def c_code(self, node, name, inames, onames, props):
         )
         value_name, *cond_names = inames
         out_name = onames[0]
-        check = []
         fail_code = props["fail"]
         param_struct_name = props["params"]
         msg = self.msg.replace('"', '\\"').replace("\n", "\\n")

-        for idx, cond_name in enumerate(cond_names):
-            if isinstance(node.inputs[0].type, DenseTensorType):
-                check.append(
-                    f"""
-        if(PyObject_IsTrue((PyObject *){cond_name}) == 0) {{
-            PyObject * exc_type = {param_struct_name}->exc_type;
-            Py_INCREF(exc_type);
-            PyErr_SetString(exc_type, "{msg}");
-            Py_XDECREF(exc_type);
-            {indent(fail_code, " " * 4)}
-        }}
-        """
-                )
-            else:
-                check.append(
-                    f"""
-        if({cond_name} == 0) {{
-            PyObject * exc_type = {param_struct_name}->exc_type;
-            Py_INCREF(exc_type);
-            PyErr_SetString(exc_type, "{msg}");
-            Py_XDECREF(exc_type);
-            {indent(fail_code, " " * 4)}
-        }}
-        """
-                )
-
-        check = "\n".join(check)
+        all_conds = " && ".join(cond_names)
+        check = f"""
+        if(!({all_conds})) {{
+            PyObject * exc_type = {param_struct_name}->exc_type;
+            Py_INCREF(exc_type);
+            PyErr_SetString(exc_type, "{msg}");
+            Py_XDECREF(exc_type);
+            {indent(fail_code, " " * 4)}
+        }}
+        """

         if isinstance(node.inputs[0].type, DenseTensorType):
             res = f"""
@@ -162,14 +128,19 @@ def c_code(self, node, name, inames, onames, props):
             {check}
             {out_name} = {value_name};
             """
-        return res
+
+        return "\n".join((check, res))

     def c_code_cache_version(self):
-        return (1, 1)
+        return (2,)

     def infer_shape(self, fgraph, node, input_shapes):
         return [input_shapes[0]]

+    def do_constant_folding(self, fgraph, node):
+        # Only constant-fold if the Assert does not fail
+        return all((isinstance(c, Constant) and bool(c.data)) for c in node.inputs[1:])
+

 class Assert(CheckAndRaise):
     """Implements assertion in a computational graph.

pytensor/tensor/rewriting/basic.py

Lines changed: 2 additions & 7 deletions
@@ -732,20 +732,15 @@ def is_an_upcast(type1, type2):

 @register_useless
 @register_specialize
-@node_rewriter(None)
+@node_rewriter([CheckAndRaise])
 def local_remove_useless_assert(fgraph, node):
-    if not isinstance(node.op, CheckAndRaise):
-        return False
-
     new_conds = []
     n_conds = len(node.inputs[1:])
     for c in node.inputs[1:]:
         try:
             const = get_scalar_constant_value(c)

-            if 0 != const.ndim or const == 0:
-                # Should we raise an error here? How to be sure it
-                # is not caught?
+            if not const:
                 new_conds.append(c)
         except NotScalarConstantError:
             new_conds.append(c)
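Registering the rewrite directly for `CheckAndRaise` means it only visits those nodes, and the simplified test keeps any condition whose constant value is falsy. A short sketch of the effect, modeled on `test_local_remove_useless_2` below (the always-true `1` condition is stripped while the symbolic `y` condition remains):

import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph
from pytensor.graph.fg import FunctionGraph
from pytensor.raise_op import assert_op

x = pt.scalar("x")
y = ps.bool("y")
fg = FunctionGraph(outputs=[assert_op(x, y, 1)], clone=False)
fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"])
# After rewriting, the Assert keeps only the `y` condition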

tests/tensor/rewriting/test_basic.py

Lines changed: 6 additions & 6 deletions
@@ -487,8 +487,8 @@ def test_local_remove_useless_1(self):

     def test_local_remove_useless_2(self):
         """Remove `CheckAndRaise` conditions that are always true."""
-        x = scalar()
-        y = scalar()
+        x = scalar("x")
+        y = ps.bool("y")
         fg = FunctionGraph(outputs=[assert_op(x, y, 1)], clone=False)
         fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"])
         topo = fg_res.toposort()
@@ -497,8 +497,8 @@ def test_local_remove_useless_2(self):

     def test_local_remove_useless_3(self):
         """Don't remove `CheckAndRaise` conditions that are always false."""
-        x = scalar()
-        y = scalar()
+        x = scalar("x")
+        y = ps.bool("y")
         fg = FunctionGraph(outputs=[assert_op(x, y, 0)], clone=False)
         fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"])
         topo = fg_res.toposort()
@@ -1559,7 +1559,7 @@ def test_local_merge_alloc():
     output = pt.alloc(pt.alloc(m, y, 1, 1), x, y2, z, w)
     f = function([m, x, y, y2, z, w], output, mode=rewrite_mode)
     topo = f.maker.fgraph.toposort()
-    assert len(topo) == 3
+    assert len(topo) == 4
     assert isinstance(topo[-2].op, Assert)
     assert isinstance(topo[-1].op, Alloc)
     o = f(0.0, 1, 2, 2, 3, 4)
@@ -1616,7 +1616,7 @@ def test_local_useless_alloc():
     useless_alloc.rewrite(g)

     topo = g.toposort()
-    assert len(topo) == 3
+    assert len(topo) == 4
     assert isinstance(topo[-2].op, Assert)
     assert isinstance(topo[-1].op, Alloc)


tests/tensor/rewriting/test_elemwise.py

Lines changed: 1 addition & 1 deletion
@@ -932,7 +932,7 @@ def large_fuseable_graph(self, n):
                 ),
                 (fx,),
                 (fxv,),
-                4,
+                5,
                 (np.zeros_like(fxv),),
                 ("float32",),
             ),

tests/tensor/test_extra_ops.py

Lines changed: 10 additions & 3 deletions
@@ -8,6 +8,7 @@
 from pytensor import tensor as pt
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
+from pytensor.graph import rewrite_graph
 from pytensor.graph.basic import Constant, applys_between, equal_computations
 from pytensor.npy_2_compat import old_np_unique
 from pytensor.raise_op import Assert
@@ -1252,11 +1253,17 @@ def test_broadcast_shape_symbolic_one_symbolic():
     ]

     res_shape = broadcast_shape(*index_shapes, arrays_are_shapes=True)
-
-    from pytensor.graph.rewriting.utils import rewrite_graph
-
     res_shape = rewrite_graph(res_shape)
+    assert res_shape[0].data == 1
+    assert res_shape[1].data == 1
+    with pytest.raises(AssertionError, match="Could not broadcast dimensions"):
+        # broadcast_shape doesn't treat int_div as a constant 1
+        res_shape[2].eval()

+    res_shape = broadcast_shape(
+        *index_shapes, arrays_are_shapes=True, allow_runtime_broadcast=True
+    )
+    res_shape = rewrite_graph(res_shape)
     assert res_shape[0].data == 1
     assert res_shape[1].data == 1
     assert res_shape[2].data == 3

tests/test_raise_op.py

Lines changed: 12 additions & 5 deletions
@@ -82,19 +82,26 @@ def test_CheckAndRaise_basic_c(linker):

     with pytest.raises(CustomException, match=exc_msg):
         y_fn(0)
+    assert y_fn(1) == 1.0

     x = pt.vector()
+    x_val = np.array([1.0], dtype=pytensor.config.floatX)
+
     y = check_and_raise(x, conds)
-    y_fn = pytensor.function([conds, x], y.shape, mode=Mode(linker, OPT_FAST_RUN))
+    y_fn = pytensor.function([conds, x], y, mode=Mode(linker, OPT_FAST_RUN))
+    with pytest.raises(CustomException, match=exc_msg):
+        y_fn(0, x_val)
+    assert np.array_equal(y_fn(1, x_val), x_val)

-    x_val = np.array([1.0], dtype=pytensor.config.floatX)
+    y_fn = pytensor.function([conds, x], y.shape, mode=Mode(linker, OPT_FAST_RUN))
+    # The shape doesn't depend on y, so the Assert is dropped from the graph
     assert np.array_equal(y_fn(0, x_val), x_val)

     y = check_and_raise(x, pt.as_tensor(0))
-    y_grad = pytensor.grad(y.sum(), [x])
+    y_grad = pytensor.grad(y.sum(), x)
     y_fn = pytensor.function([x], y_grad, mode=Mode(linker, OPT_FAST_RUN))
-
-    assert np.array_equal(y_fn(x_val), [x_val])
+    # The gradient doesn't depend on y, just its shape, so the Assert is dropped from the graph
+    assert np.array_equal(y_fn(x_val), x_val)


 @pytest.mark.parametrize(
