|
76 | 76 | sub, |
77 | 77 | tri_gamma, |
78 | 78 | true_div, |
| 79 | + variadic_add, |
| 80 | + variadic_mul, |
79 | 81 | ) |
80 | 82 | from pytensor.tensor.math import abs as pt_abs |
81 | 83 | from pytensor.tensor.math import max as pt_max |
@@ -1270,17 +1272,13 @@ def local_sum_prod_of_mul_or_div(fgraph, node): |
1270 | 1272 |
|
1271 | 1273 | if not outer_terms: |
1272 | 1274 | return None |
1273 | | - elif len(outer_terms) == 1: |
1274 | | - [outer_term] = outer_terms |
1275 | 1275 | else: |
1276 | | - outer_term = mul(*outer_terms) |
| 1276 | + outer_term = variadic_mul(*outer_terms) |
1277 | 1277 |
|
1278 | 1278 | if not inner_terms: |
1279 | 1279 | inner_term = None |
1280 | | - elif len(inner_terms) == 1: |
1281 | | - [inner_term] = inner_terms |
1282 | 1280 | else: |
1283 | | - inner_term = mul(*inner_terms) |
| 1281 | + inner_term = variadic_mul(*inner_terms) |
1284 | 1282 |
|
1285 | 1283 | else: # true_div |
1286 | 1284 | # We only care about removing the denominator out of the reduction |
@@ -2143,10 +2141,7 @@ def local_add_remove_zeros(fgraph, node): |
2143 | 2141 | assert cst.type.broadcastable == (True,) * ndim |
2144 | 2142 | return [alloc_like(cst, node_output, fgraph)] |
2145 | 2143 |
|
2146 | | - if len(new_inputs) == 1: |
2147 | | - ret = [alloc_like(new_inputs[0], node_output, fgraph)] |
2148 | | - else: |
2149 | | - ret = [alloc_like(add(*new_inputs), node_output, fgraph)] |
| 2144 | + ret = [alloc_like(variadic_add(*new_inputs), node_output, fgraph)] |
2150 | 2145 |
|
2151 | 2146 | # The dtype should not be changed. It can happen if the input |
2152 | 2147 | # that was forcing upcasting was equal to 0. |
@@ -2257,10 +2252,7 @@ def local_log1p(fgraph, node): |
2257 | 2252 | # scalar_inputs are potentially dimshuffled and fill'd scalars |
2258 | 2253 | if scalars and np.allclose(np.sum(scalars), 1): |
2259 | 2254 | if nonconsts: |
2260 | | - if len(nonconsts) > 1: |
2261 | | - ninp = add(*nonconsts) |
2262 | | - else: |
2263 | | - ninp = nonconsts[0] |
| 2255 | + ninp = variadic_add(*nonconsts) |
2264 | 2256 | if ninp.dtype != log_arg.type.dtype: |
2265 | 2257 | ninp = ninp.astype(node.outputs[0].dtype) |
2266 | 2258 | return [alloc_like(log1p(ninp), node.outputs[0], fgraph)] |
@@ -3084,10 +3076,7 @@ def local_exp_over_1_plus_exp(fgraph, node): |
3084 | 3076 | return |
3085 | 3077 | # put the new numerator together |
3086 | 3078 | new_num = sigmoids + [exp(t) for t in num_exp_x] + num_rest |
3087 | | - if len(new_num) == 1: |
3088 | | - new_num = new_num[0] |
3089 | | - else: |
3090 | | - new_num = mul(*new_num) |
| 3079 | + new_num = variadic_mul(*new_num) |
3091 | 3080 |
|
3092 | 3081 | if num_neg ^ denom_neg: |
3093 | 3082 | new_num = -new_num |
|
0 commit comments