Skip to content

Commit c946160

Browse files
committed
Use second for broadcast_arrays and remove fill_chain helper
1 parent 74d7825 commit c946160

File tree

2 files changed

+29
-18
lines changed

2 files changed

+29
-18
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from pytensor.scalar import upcast
2424
from pytensor.tensor import as_tensor_variable
2525
from pytensor.tensor import basic as at
26-
from pytensor.tensor import get_vector_length
26+
from pytensor.tensor.basic import get_vector_length, second
2727
from pytensor.tensor.exceptions import NotScalarConstantError
2828
from pytensor.tensor.math import abs as pt_abs
2929
from pytensor.tensor.math import all as pt_all
@@ -1780,7 +1780,19 @@ def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
17801780
The arrays to broadcast.
17811781
17821782
"""
1783-
return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args)
1783+
1784+
def broadcast_with_others(a, others):
1785+
for other in others:
1786+
a = second(other, a)
1787+
return a
1788+
1789+
broadcasted_vars = []
1790+
for i, a in enumerate(args):
1791+
# We use indexing and not identity in case there are duplicated variables
1792+
others = [a for j, a in enumerate(args) if j != i]
1793+
broadcasted_vars.append(broadcast_with_others(a, others))
1794+
1795+
return broadcasted_vars
17841796

17851797

17861798
__all__ = [

pytensor/tensor/rewriting/math.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
4040
from pytensor.tensor.exceptions import NotScalarConstantError
41+
from pytensor.tensor.extra_ops import broadcast_arrays
4142
from pytensor.tensor.math import (
4243
All,
4344
Any,
@@ -148,12 +149,6 @@ def get_constant(v):
148149
return v
149150

150151

151-
def fill_chain(new_out, orig_inputs):
152-
for i in orig_inputs:
153-
new_out = fill(i, new_out)
154-
return [new_out]
155-
156-
157152
@register_canonicalize
158153
@register_stabilize
159154
@node_rewriter([Dot])
@@ -1136,7 +1131,7 @@ def same(x, y):
11361131
new = cast(new, out.type.dtype)
11371132

11381133
if new.type.broadcastable != out.type.broadcastable:
1139-
new = fill_chain(new, node.inputs)[0]
1134+
new = broadcast_arrays(new, *node.inputs)[0]
11401135

11411136
if (new.type.dtype == out.type.dtype) and (
11421137
new.type.broadcastable == out.type.broadcastable
@@ -1961,7 +1956,9 @@ def local_mul_zero(fgraph, node):
19611956
# print 'MUL by value', value, node.inputs
19621957
if value == 0:
19631958
# print '... returning zeros'
1964-
return fill_chain(_asarray(0, dtype=otype.dtype), node.inputs)
1959+
return [
1960+
broadcast_arrays(_asarray(0, dtype=otype.dtype), *node.inputs)[0]
1961+
]
19651962

19661963

19671964
# TODO: Add this to the canonicalization to reduce redundancy.
@@ -2260,12 +2257,12 @@ def local_add_specialize(fgraph, node):
22602257
# Reuse call to constant for cache()
22612258
cst = constant(np.zeros((1,) * ndim, dtype=dtype))
22622259
assert cst.type.broadcastable == (True,) * ndim
2263-
return fill_chain(cst, node.inputs)
2260+
return [broadcast_arrays(cst, *node.inputs)[0]]
22642261

22652262
if len(new_inputs) == 1:
2266-
ret = fill_chain(new_inputs[0], node.inputs)
2263+
ret = [broadcast_arrays(new_inputs[0], *node.inputs)[0]]
22672264
else:
2268-
ret = fill_chain(add(*new_inputs), node.inputs)
2265+
ret = [broadcast_arrays(add(*new_inputs), *node.inputs)[0]]
22692266

22702267
# The dtype should not be changed. It can happen if the input
22712268
# that was forcing upcasting was equal to 0.
@@ -2383,7 +2380,7 @@ def local_log1p(fgraph, node):
23832380
ninp = nonconsts[0]
23842381
if ninp.dtype != log_arg.type.dtype:
23852382
ninp = ninp.astype(node.outputs[0].dtype)
2386-
return fill_chain(log1p(ninp), scalar_inputs)
2383+
return [broadcast_arrays(log1p(ninp), *scalar_inputs)[0]]
23872384

23882385
elif log_arg.owner and log_arg.owner.op == sub:
23892386
one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True)
@@ -3578,10 +3575,12 @@ def local_reciprocal_1_plus_exp(fgraph, node):
35783575
if len(nonconsts) == 1:
35793576
if nonconsts[0].owner and nonconsts[0].owner.op == exp:
35803577
if scalars_ and np.allclose(np.sum(scalars_), 1):
3581-
out = fill_chain(
3582-
sigmoid(neg(nonconsts[0].owner.inputs[0])),
3583-
scalar_inputs,
3584-
)
3578+
out = [
3579+
broadcast_arrays(
3580+
sigmoid(neg(nonconsts[0].owner.inputs[0])),
3581+
*scalar_inputs,
3582+
)[0]
3583+
]
35853584
# keep combined stack traces of
35863585
# exp(x): nonconsts[0],
35873586
# 1 + exp(x): reciprocal_arg,

0 commit comments

Comments (0)