@@ -958,6 +958,25 @@ def local_sum_make_vector(fgraph, node):
958
958
return [element_sum ]
959
959
960
960
961
+ def equivalent_up_to_constant_casting (a , b ) -> bool :
962
+ """Return True if a and b are equivalent up to constant casting."""
963
+ if a == b :
964
+ return True
965
+ # Return equivalence based on data values, ignoring dtype
966
+ if (
967
+ isinstance (a , TensorConstant )
968
+ and isinstance (b , TensorConstant )
969
+ and a .type .shape == b .type .shape
970
+ # We don't want to spend a lot of time comparing large constant arrays
971
+ # First, check if dtype matches, otherwise a == b would be true if they hold the same values
972
+ and a .type .dtype != b .type .dtype
973
+ # Check property sum() that is cached for TensorConstants, to filter down candidates even more
974
+ and a .signature ().sum == b .signature ().sum
975
+ ):
976
+ return np .array_equal (a .data , b .data )
977
+ return False
978
+
979
+
961
980
@register_useless ("shape_unsafe" )
962
981
@register_canonicalize ("fast_compile" , "shape_unsafe" )
963
982
@register_specialize ("shape_unsafe" )
@@ -1004,17 +1023,19 @@ def local_useless_switch(fgraph, node):
1004
1023
return [out ]
1005
1024
1006
1025
# if left is right -> left
1007
- if left == right :
1008
- # Note: No need to copy over stacktrace, because the input node
1009
- # already has its own stacktrace
1026
+ if equivalent_up_to_constant_casting (left , right ):
1010
1027
if left .type .broadcastable == out_bcast :
1028
+ out_dtype = node .outputs [0 ].type .dtype
1029
+ if left .type .dtype != out_dtype :
1030
+ left = cast (left , out_dtype )
1031
+ copy_stack_trace (node .outputs + left , left )
1032
+ # When not casting, the other inputs of the switch aren't needed in the traceback
1011
1033
return [left ]
1012
1034
1013
- ret = broadcast_arrays (left , cond )[0 ]
1014
-
1015
- # Copy over stacktrace from switch output and correct branch
1016
- copy_stack_trace (node .outputs + left , ret )
1017
- return [ret ]
1035
+ else :
1036
+ ret = broadcast_arrays (left , cond )[0 ]
1037
+ copy_stack_trace (node .outputs + left , ret )
1038
+ return [ret ]
1018
1039
1019
1040
# This case happens with scan.
1020
1041
# Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
0 commit comments