Skip to content

Commit 82823de

Browse files
committed
Remove switches when both branches are equivalent constants with different dtype
1 parent aad78d5 commit 82823de

File tree

2 files changed

+45
-8
lines changed

2 files changed

+45
-8
lines changed

pytensor/tensor/rewriting/basic.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,25 @@ def local_sum_make_vector(fgraph, node):
958958
return [element_sum]
959959

960960

961+
def equivalent_up_to_constant_casting(a, b) -> bool:
962+
"""Return True if a and b are equivalent up to constant casting."""
963+
if a == b:
964+
return True
965+
# Return equivalence based on data values, ignoring dtype
966+
if (
967+
isinstance(a, TensorConstant)
968+
and isinstance(b, TensorConstant)
969+
and a.type.shape == b.type.shape
970+
# We don't want to spend a lot of time comparing large constant arrays
971+
# First, check if dtype matches, otherwise a == b would be true if they hold the same values
972+
and a.type.dtype != b.type.dtype
973+
# Check property sum() that is cached for TensorConstants, to filter down candidates even more
974+
and a.signature().sum == b.signature().sum
975+
):
976+
return np.array_equal(a.data, b.data)
977+
return False
978+
979+
961980
@register_useless("shape_unsafe")
962981
@register_canonicalize("fast_compile", "shape_unsafe")
963982
@register_specialize("shape_unsafe")
@@ -1004,17 +1023,19 @@ def local_useless_switch(fgraph, node):
10041023
return [out]
10051024

10061025
# if left is right -> left
1007-
if left == right:
1008-
# Note: No need to copy over stacktrace, because the input node
1009-
# already has its own stacktrace
1026+
if equivalent_up_to_constant_casting(left, right):
10101027
if left.type.broadcastable == out_bcast:
1028+
out_dtype = node.outputs[0].type.dtype
1029+
if left.type.dtype != out_dtype:
1030+
left = cast(left, out_dtype)
1031+
copy_stack_trace(node.outputs + left, left)
1032+
# When not casting, the other inputs of the switch aren't needed in the traceback
10111033
return [left]
10121034

1013-
ret = broadcast_arrays(left, cond)[0]
1014-
1015-
# Copy over stacktrace from switch output and correct branch
1016-
copy_stack_trace(node.outputs + left, ret)
1017-
return [ret]
1035+
else:
1036+
ret = broadcast_arrays(left, cond)[0]
1037+
copy_stack_trace(node.outputs + left, ret)
1038+
return [ret]
10181039

10191040
# This case happens with scan.
10201041
# Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)

tests/tensor/rewriting/test_basic.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
ScalarFromTensor,
2727
Split,
2828
TensorFromScalar,
29+
as_tensor,
2930
cast,
3031
join,
3132
tile,
@@ -983,6 +984,21 @@ def test_left_is_right(self, dtype1):
983984
assert np.array_equal(f0(vx), vx)
984985
assert np.array_equal(f2(vx, vc), vx)
985986

987+
def test_left_is_right_constant(self):
988+
int8_one = as_tensor(np.int8(1))
989+
int8_zero = as_tensor(np.int8(0))
990+
int64_zero = as_tensor(np.int64(0))
991+
cond = scalar("cond", dtype=bool)
992+
993+
out = pt.switch(cond, int8_zero, int64_zero)
994+
assert equal_computations([rewrite_graph(out)], [int64_zero])
995+
996+
out = pt.switch(cond, int64_zero, int8_zero)
997+
assert equal_computations([rewrite_graph(out)], [int64_zero])
998+
999+
out = pt.switch(cond, int8_one, int8_zero)
1000+
assert equal_computations([rewrite_graph(out)], [out])
1001+
9861002
@pytest.mark.parametrize(
9871003
"dtype1",
9881004
["float32", "float64"],

0 commit comments

Comments
 (0)