|
38 | 38 | )
|
39 | 39 | from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
|
40 | 40 | from pytensor.tensor.exceptions import NotScalarConstantError
|
| 41 | +from pytensor.tensor.extra_ops import broadcast_arrays |
41 | 42 | from pytensor.tensor.math import (
|
42 | 43 | All,
|
43 | 44 | Any,
|
@@ -148,12 +149,6 @@ def get_constant(v):
|
148 | 149 | return v
|
149 | 150 |
|
150 | 151 |
|
151 |
| -def fill_chain(new_out, orig_inputs): |
152 |
| - for i in orig_inputs: |
153 |
| - new_out = fill(i, new_out) |
154 |
| - return [new_out] |
155 |
| - |
156 |
| - |
157 | 152 | @register_canonicalize
|
158 | 153 | @register_stabilize
|
159 | 154 | @node_rewriter([Dot])
|
@@ -1136,7 +1131,7 @@ def same(x, y):
|
1136 | 1131 | new = cast(new, out.type.dtype)
|
1137 | 1132 |
|
1138 | 1133 | if new.type.broadcastable != out.type.broadcastable:
|
1139 |
| - new = fill_chain(new, node.inputs)[0] |
| 1134 | + new = broadcast_arrays(new, *node.inputs)[0] |
1140 | 1135 |
|
1141 | 1136 | if (new.type.dtype == out.type.dtype) and (
|
1142 | 1137 | new.type.broadcastable == out.type.broadcastable
|
@@ -1961,7 +1956,9 @@ def local_mul_zero(fgraph, node):
|
1961 | 1956 | # print 'MUL by value', value, node.inputs
|
1962 | 1957 | if value == 0:
|
1963 | 1958 | # print '... returning zeros'
|
1964 |
| - return fill_chain(_asarray(0, dtype=otype.dtype), node.inputs) |
| 1959 | + return [ |
| 1960 | + broadcast_arrays(_asarray(0, dtype=otype.dtype), *node.inputs)[0] |
| 1961 | + ] |
1965 | 1962 |
|
1966 | 1963 |
|
1967 | 1964 | # TODO: Add this to the canonicalization to reduce redundancy.
|
@@ -2260,12 +2257,12 @@ def local_add_specialize(fgraph, node):
|
2260 | 2257 | # Reuse call to constant for cache()
|
2261 | 2258 | cst = constant(np.zeros((1,) * ndim, dtype=dtype))
|
2262 | 2259 | assert cst.type.broadcastable == (True,) * ndim
|
2263 |
| - return fill_chain(cst, node.inputs) |
| 2260 | + return [broadcast_arrays(cst, *node.inputs)[0]] |
2264 | 2261 |
|
2265 | 2262 | if len(new_inputs) == 1:
|
2266 |
| - ret = fill_chain(new_inputs[0], node.inputs) |
| 2263 | + ret = [broadcast_arrays(new_inputs[0], *node.inputs)[0]] |
2267 | 2264 | else:
|
2268 |
| - ret = fill_chain(add(*new_inputs), node.inputs) |
| 2265 | + ret = [broadcast_arrays(add(*new_inputs), *node.inputs)[0]] |
2269 | 2266 |
|
2270 | 2267 | # The dtype should not be changed. It can happen if the input
|
2271 | 2268 | # that was forcing upcasting was equal to 0.
|
@@ -2383,7 +2380,7 @@ def local_log1p(fgraph, node):
|
2383 | 2380 | ninp = nonconsts[0]
|
2384 | 2381 | if ninp.dtype != log_arg.type.dtype:
|
2385 | 2382 | ninp = ninp.astype(node.outputs[0].dtype)
|
2386 |
| - return fill_chain(log1p(ninp), scalar_inputs) |
| 2383 | + return [broadcast_arrays(log1p(ninp), *scalar_inputs)[0]] |
2387 | 2384 |
|
2388 | 2385 | elif log_arg.owner and log_arg.owner.op == sub:
|
2389 | 2386 | one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True)
|
@@ -3578,10 +3575,12 @@ def local_reciprocal_1_plus_exp(fgraph, node):
|
3578 | 3575 | if len(nonconsts) == 1:
|
3579 | 3576 | if nonconsts[0].owner and nonconsts[0].owner.op == exp:
|
3580 | 3577 | if scalars_ and np.allclose(np.sum(scalars_), 1):
|
3581 |
| - out = fill_chain( |
3582 |
| - sigmoid(neg(nonconsts[0].owner.inputs[0])), |
3583 |
| - scalar_inputs, |
3584 |
| - ) |
| 3578 | + out = [ |
| 3579 | + broadcast_arrays( |
| 3580 | + sigmoid(neg(nonconsts[0].owner.inputs[0])), |
| 3581 | + *scalar_inputs, |
| 3582 | + )[0] |
| 3583 | + ] |
3585 | 3584 | # keep combined stack traces of
|
3586 | 3585 | # exp(x): nonconsts[0],
|
3587 | 3586 | # 1 + exp(x): reciprocal_arg,
|
|
0 commit comments