Skip to content

Commit 316ce0b

Browse files
committed
Refactor encompasses_broadcastable to broadcasted_by
1 parent c946160 commit 316ce0b

File tree

2 files changed

+17
-22
lines changed

2 files changed

+17
-22
lines changed

pytensor/tensor/rewriting/basic.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from pytensor.tensor.shape import Shape_i
5050
from pytensor.tensor.sort import TopKOp
5151
from pytensor.tensor.type import DenseTensorType, TensorType
52-
from pytensor.tensor.var import TensorConstant
52+
from pytensor.tensor.var import TensorConstant, TensorVariable
5353
from pytensor.utils import NoDuplicateOptWarningFilter
5454

5555

@@ -61,27 +61,26 @@
6161
_logger.addFilter(NoDuplicateOptWarningFilter())
6262

6363

64-
def encompasses_broadcastable(b1, b2):
65-
"""
64+
def broadcasted_by(x: TensorVariable, y: TensorVariable) -> bool:
65+
"""Check whether x would be broadcasted by y in an Elemwise operation
6666
6767
Parameters
6868
----------
69-
b1
70-
The broadcastable attribute of a tensor type.
71-
b2
72-
The broadcastable attribute of a tensor type.
69+
x: TensorVariable
70+
The variable that may be broadcasted by y
71+
y: TensorVariable
72+
The variable that may broadcast x
7373
7474
Returns
7575
-------
76-
bool
77-
True if the broadcastable patterns b1 and b2 are such that b2 is
78-
broadcasted to b1's shape and not the opposite.
79-
76+
broadcasted_by: bool
8077
"""
81-
if len(b1) < len(b2):
82-
return False
83-
b1 = b1[-len(b2) :]
84-
return not any(v1 and not v2 for v1, v2 in zip(b1, b2))
78+
bx = x.type.broadcastable
79+
by = y.type.broadcastable
80+
if len(bx) < len(by):
81+
return True
82+
bx = bx[-len(by) :]
83+
return any(bx_dim and not by_dim for bx_dim, by_dim in zip(bx, by))
8584

8685

8786
def merge_broadcastables(broadcastables):

pytensor/tensor/rewriting/math.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
from pytensor.tensor.math import true_div
8686
from pytensor.tensor.rewriting.basic import (
8787
broadcast_like,
88-
encompasses_broadcastable,
88+
broadcasted_by,
8989
local_fill_sink,
9090
register_canonicalize,
9191
register_specialize,
@@ -2049,9 +2049,7 @@ def local_pow_specialize(fgraph, node):
20492049
xsym = node.inputs[0]
20502050
ysym = node.inputs[1]
20512051
y = get_constant(ysym)
2052-
if (y is not None) and encompasses_broadcastable(
2053-
xsym.type.broadcastable, ysym.type.broadcastable
2054-
):
2052+
if (y is not None) and not broadcasted_by(xsym, ysym):
20552053
rval = None
20562054

20572055
if np.all(y == 2):
@@ -2107,9 +2105,7 @@ def local_pow_to_nested_squaring(fgraph, node):
21072105
y = y[0]
21082106
except IndexError:
21092107
pass
2110-
if (y is not None) and encompasses_broadcastable(
2111-
xsym.type.broadcastable, ysym.type.broadcastable
2112-
):
2108+
if (y is not None) and not broadcasted_by(xsym, ysym):
21132109
rval = None
21142110
# 512 is too small for the cpu and too big for some gpu!
21152111
if abs(y) == int(abs(y)) and abs(y) <= 512:

0 commit comments

Comments (0)