File tree Expand file tree Collapse file tree 2 files changed +17
-22
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 2 files changed +17
-22
lines changed Original file line number Diff line number Diff line change 49
49
from pytensor .tensor .shape import Shape_i
50
50
from pytensor .tensor .sort import TopKOp
51
51
from pytensor .tensor .type import DenseTensorType , TensorType
52
- from pytensor .tensor .var import TensorConstant
52
+ from pytensor .tensor .var import TensorConstant , TensorVariable
53
53
from pytensor .utils import NoDuplicateOptWarningFilter
54
54
55
55
61
61
_logger .addFilter (NoDuplicateOptWarningFilter ())
62
62
63
63
64
- def encompasses_broadcastable ( b1 , b2 ) :
65
- """
64
+ def broadcasted_by ( x : TensorVariable , y : TensorVariable ) -> bool :
65
+ """Check whether x would be broadcasted by y in an Elemwise operation
66
66
67
67
Parameters
68
68
----------
69
- b1
70
- The broadcastable attribute of a tensor type.
71
- b2
72
- The broadcastable attribute of a tensor type.
69
+ x: TensorVariable
70
+ The variable that may be broadcasted by y
71
+ y: TensorVariable
72
+ The variable that may broadcast x
73
73
74
74
Returns
75
75
-------
76
- bool
77
- True if the broadcastable patterns b1 and b2 are such that b2 is
78
- broadcasted to b1's shape and not the opposite.
79
-
76
+ broadcasted_by: bool
80
77
"""
81
- if len (b1 ) < len (b2 ):
82
- return False
83
- b1 = b1 [- len (b2 ) :]
84
- return not any (v1 and not v2 for v1 , v2 in zip (b1 , b2 ))
78
+ bx = x .type .broadcastable
79
+ by = y .type .broadcastable
80
+ if len (bx ) < len (by ):
81
+ return True
82
+ bx = bx [- len (by ) :]
83
+ return any (bx_dim and not by_dim for bx_dim , by_dim in zip (bx , by ))
85
84
86
85
87
86
def merge_broadcastables (broadcastables ):
Original file line number Diff line number Diff line change 85
85
from pytensor .tensor .math import true_div
86
86
from pytensor .tensor .rewriting .basic import (
87
87
broadcast_like ,
88
- encompasses_broadcastable ,
88
+ broadcasted_by ,
89
89
local_fill_sink ,
90
90
register_canonicalize ,
91
91
register_specialize ,
@@ -2049,9 +2049,7 @@ def local_pow_specialize(fgraph, node):
2049
2049
xsym = node .inputs [0 ]
2050
2050
ysym = node .inputs [1 ]
2051
2051
y = get_constant (ysym )
2052
- if (y is not None ) and encompasses_broadcastable (
2053
- xsym .type .broadcastable , ysym .type .broadcastable
2054
- ):
2052
+ if (y is not None ) and not broadcasted_by (xsym , ysym ):
2055
2053
rval = None
2056
2054
2057
2055
if np .all (y == 2 ):
@@ -2107,9 +2105,7 @@ def local_pow_to_nested_squaring(fgraph, node):
2107
2105
y = y [0 ]
2108
2106
except IndexError :
2109
2107
pass
2110
- if (y is not None ) and encompasses_broadcastable (
2111
- xsym .type .broadcastable , ysym .type .broadcastable
2112
- ):
2108
+ if (y is not None ) and not broadcasted_by (xsym , ysym ):
2113
2109
rval = None
2114
2110
# 512 is too small for the cpu and too big for some gpu!
2115
2111
if abs (y ) == int (abs (y )) and abs (y ) <= 512 :
You can’t perform that action at this time.
0 commit comments