
Commit 67519be

Rename broadcast_like to alloc_like

Parent: 316ce0b

3 files changed: 20 additions, 19 deletions

pytensor/tensor/rewriting/basic.py: 9 additions, 8 deletions

@@ -8,6 +8,7 @@
 import pytensor.scalar.basic as aes
 from pytensor import compile
 from pytensor.compile.ops import ViewOp
+from pytensor.graph import FunctionGraph
 from pytensor.graph.basic import Constant, Variable
 from pytensor.graph.rewriting.basic import (
     NodeRewriter,
@@ -87,13 +88,13 @@ def merge_broadcastables(broadcastables):
     return [all(bcast) for bcast in zip(*broadcastables)]


-def broadcast_like(value, template, fgraph, dtype=None):
-    """
-    Return a Variable with the same shape and dtype as the template,
-    filled by broadcasting value through it. `value` will be cast as
-    necessary.
-
-    """
+def alloc_like(
+    value: TensorVariable,
+    template: TensorVariable,
+    fgraph: FunctionGraph,
+    dtype=None,
+) -> TensorVariable:
+    """Fill value to the same shape and dtype as the template via alloc."""
     value = as_tensor_variable(value)
     if value.type.is_super(template.type):
         return value
@@ -438,7 +439,7 @@ def local_fill_to_alloc(fgraph, node):
     # In this case, we assume that some broadcasting is needed (otherwise
     # the condition above would've been true), so we replace the `fill`
     # with an `Alloc`.
-    o = broadcast_like(values_ref, shape_ref, fgraph, dtype=values_ref.dtype)
+    o = alloc_like(values_ref, shape_ref, fgraph, dtype=values_ref.dtype)
     copy_stack_trace(node.outputs[0], o)
     return [o]
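For orientation, the renamed helper builds the same kind of graph a user gets from `pt.alloc`: the value filled out to the template's shape, hence the new name. A minimal sketch of the equivalent user-level graph (the variable names here are illustrative, not from the commit):

    import pytensor.tensor as pt

    template = pt.matrix("template")  # the shape/dtype reference
    value = pt.scalar("value", dtype="float64")
    # Fill `value` out to the template's symbolic shape, as Alloc does:
    filled = pt.alloc(value, template.shape[0], template.shape[1])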

pytensor/tensor/rewriting/elemwise.py: 2 additions, 2 deletions

@@ -34,7 +34,7 @@
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import exp
 from pytensor.tensor.rewriting.basic import (
-    broadcast_like,
+    alloc_like,
     register_canonicalize,
     register_specialize,
 )
@@ -1242,7 +1242,7 @@ def local_inline_composite_constants(fgraph, node):
     # Some of the inlined constants were broadcasting the output shape
     if node.outputs[0].type.broadcastable != new_outputs[0].type.broadcastable:
         new_outputs = [
-            broadcast_like(new_out, template=node.outputs[0], fgraph=fgraph)
+            alloc_like(new_out, template=node.outputs[0], fgraph=fgraph)
             for new_out in new_outputs
         ]
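The guard in this hunk compares broadcastable patterns because a rewrite may only substitute a variable of a compatible type; when inlining constants changes the pattern, the new outputs are alloc'ed back to the original shape. A toy illustration of the mismatch being guarded against (assumed names, not from the commit):

    import pytensor.tensor as pt

    out = pt.matrix("out")           # broadcastable == (False, False)
    candidate = pt.row("candidate")  # broadcastable == (True, False)
    # A replacement with a different broadcastable pattern must be
    # alloc'ed to the original output's shape before substituting for it.
    assert out.type.broadcastable != candidate.type.broadcastable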

pytensor/tensor/rewriting/math.py: 9 additions, 9 deletions

@@ -84,7 +84,7 @@
 from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.math import true_div
 from pytensor.tensor.rewriting.basic import (
-    broadcast_like,
+    alloc_like,
     broadcasted_by,
     local_fill_sink,
     register_canonicalize,
@@ -1973,7 +1973,7 @@ def local_div_to_reciprocal(fgraph, node):
             new_out = cast(new_out, dtype=out.dtype)
         # The ones could have forced a specific length
         if not out.type.is_super(new_out.type):
-            new_out = broadcast_like(new_out, out, fgraph)
+            new_out = alloc_like(new_out, out, fgraph)
         return [new_out]
     else:
         return False
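A hedged sketch of what this rewrite does at the user level (the variable names are illustrative): division of ones by an expression becomes `reciprocal`, with `alloc_like` applied only when the constant ones had pinned down a more specific static shape.

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    # 1.0 / x is rewritten to reciprocal(x) in the compiled graph.
    f = pytensor.function([x], 1.0 / x)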
@@ -1994,9 +1994,9 @@ def local_pow_canonicalize(fgraph, node):
     if node.op == at_pow:
         cst = get_constant(node.inputs[1])
         if cst == 0:
-            return [broadcast_like(1, node.outputs[0], fgraph)]
+            return [alloc_like(1, node.outputs[0], fgraph)]
         if cst == 1:
-            return [broadcast_like(node.inputs[0], node.outputs[0], fgraph)]
+            return [alloc_like(node.inputs[0], node.outputs[0], fgraph)]
     else:
         return False
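The effect of this canonicalization, sketched at the user level (illustrative names, not from the commit):

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    # x ** 0 canonicalizes to an Alloc of 1 with the output's shape and
    # dtype; x ** 1 canonicalizes back to x (alloc'ed only if types differ).
    f = pytensor.function([x], [x**0, x**1])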

@@ -2033,7 +2033,7 @@ def local_zero_div(fgraph, node):
         node.op.scalar_op, (aes.IntDiv, aes.TrueDiv)
     ):
         if get_constant(node.inputs[0]) == 0:
-            ret = broadcast_like(0, node.outputs[0], fgraph)
+            ret = alloc_like(0, node.outputs[0], fgraph)
             ret.tag.values_eq_approx = values_eq_approx_remove_nan
             return [ret]
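The `values_eq_approx_remove_nan` tag appears to cover the 0 / 0 corner case: the rewritten all-zeros graph differs from the original only at entries where the denominator is itself zero (nan), which the tag tells comparison utilities to ignore. A hedged sketch of the rewrite's effect:

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    # 0.0 / x is rewritten to an Alloc of 0 with the output's shape/dtype.
    f = pytensor.function([x], 0.0 / x)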

@@ -2184,7 +2184,7 @@ def local_mul_specialize(fgraph, node):
                 has_neg ^= True  # toggles
             elif y == 0.0:
                 # if we find any zero, we just return right away
-                return [broadcast_like(0, node.outputs[0], fgraph)]
+                return [alloc_like(0, node.outputs[0], fgraph)]
             else:
                 new_inputs.append(inp)

@@ -2209,14 +2209,14 @@ def local_mul_specialize(fgraph, node):
                     new_inputs = [m1] + new_inputs
                 rval = mul(*new_inputs)

-                return [broadcast_like(rval, node.outputs[0], fgraph)]
+                return [alloc_like(rval, node.outputs[0], fgraph)]
             else:
                 # there are no variable inputs to mul
                 # N.B. this could have been constant-folded...
                 if has_neg:
-                    return [broadcast_like(-1, node.outputs[0], fgraph)]
+                    return [alloc_like(-1, node.outputs[0], fgraph)]
                 else:
-                    return [broadcast_like(1, node.outputs[0], fgraph)]
+                    return [alloc_like(1, node.outputs[0], fgraph)]


 @register_specialize
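A user-level sketch of the cases `local_mul_specialize` handles (hedged, with illustrative names): constant factors of 0, 1, and -1 are collapsed, and the result is alloc'ed to the output type only when broadcasting would otherwise change it.

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    # x * 0 -> Alloc of 0; x * 1 -> x; paired negations cancel.
    f = pytensor.function([x], [x * 0, x * 1, (-x) * -1])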
