Commit 2e545cd

Simplify logic with variadic_add and variadic_mul helpers
1 parent b2e0205 commit 2e545cd

6 files changed (+53, -59 lines)

pytensor/tensor/blas.py

Lines changed: 2 additions & 6 deletions
@@ -102,7 +102,7 @@
 from pytensor.tensor.basic import expand_dims
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
 from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import add, mul, neg, sub
+from pytensor.tensor.math import add, mul, neg, sub, variadic_add
 from pytensor.tensor.shape import shape_padright, specify_broadcastable
 from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor

@@ -1399,11 +1399,7 @@ def item_to_var(t):
             item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
         ]
         add_inputs.extend(gemm_of_sM_list)
-        if len(add_inputs) > 1:
-            rval = [add(*add_inputs)]
-        else:
-            rval = add_inputs
-        # print "RETURNING GEMM THING", rval
+        rval = [variadic_add(*add_inputs)]
         return rval, old_dot22

pytensor/tensor/math.py

Lines changed: 24 additions & 12 deletions
@@ -1430,18 +1430,12 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None)
     else:
         shp = cast(shp, "float64")
 
-    if axis is None:
-        axis = list(range(input.ndim))
-    elif isinstance(axis, int | np.integer):
-        axis = [axis]
-    elif isinstance(axis, np.ndarray) and axis.ndim == 0:
-        axis = [int(axis)]
-    else:
-        axis = [int(a) for a in axis]
-
-    # This sequential division will possibly be optimized by PyTensor:
-    for i in axis:
-        s = true_div(s, shp[i])
+    reduced_dims = (
+        shp
+        if axis is None
+        else [shp[i] for i in normalize_axis_tuple(axis, input.type.ndim)]
+    )
+    s /= variadic_mul(*reduced_dims)
 
     # This can happen when axis is an empty list/tuple
     if s.dtype != shp.dtype and s.dtype in discrete_dtypes:
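
With this hunk, mean divides the summed tensor once by the product of the reduced dimension sizes (via variadic_mul) instead of chaining one true_div per axis, and normalize_axis_tuple (NumPy's axis-normalization helper) takes care of negative or scalar axis values. A rough illustration, with a made-up 3-D input:

    import pytensor.tensor as pt

    x = pt.tensor3("x")        # illustrative 3-D input
    m = x.mean(axis=(0, -1))   # axes normalize to (0, 2)
    # The sum over axes 0 and 2 is divided once by shp[0] * shp[2]
    # rather than by two chained true_div calls.
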
@@ -1597,6 +1591,15 @@ def add(a, *other_terms):
     # see decorator for function body
 
 
+def variadic_add(*args):
+    """Add that accepts arbitrary number of inputs, including zero or one."""
+    if not args:
+        return 0
+    if len(args) == 1:
+        return args[0]
+    return add(*args)
+
+
 @scalar_elemwise
 def sub(a, b):
     """elementwise subtraction"""
@@ -1609,6 +1612,15 @@ def mul(a, *other_terms):
     # see decorator for function body
 
 
+def variadic_mul(*args):
+    """Mul that accepts arbitrary number of inputs, including zero or one."""
+    if not args:
+        return 1
+    if len(args) == 1:
+        return args[0]
+    return mul(*args)
+
+
 @scalar_elemwise
 def true_div(a, b):
     """elementwise [true] division (inverse of multiplication)"""

pytensor/tensor/rewriting/basic.py

Lines changed: 4 additions & 9 deletions
@@ -68,7 +68,7 @@
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_arrays
-from pytensor.tensor.math import Sum, add, eq
+from pytensor.tensor.math import Sum, add, eq, variadic_add
 from pytensor.tensor.shape import Shape_i, shape_padleft
 from pytensor.tensor.type import DenseTensorType, TensorType
 from pytensor.tensor.variable import TensorConstant, TensorVariable
@@ -939,14 +939,9 @@ def local_sum_make_vector(fgraph, node):
     if acc_dtype == "float64" and out_dtype != "float64" and config.floatX != "float64":
         return
 
-    if len(elements) == 0:
-        element_sum = zeros(dtype=out_dtype, shape=())
-    elif len(elements) == 1:
-        element_sum = cast(elements[0], out_dtype)
-    else:
-        element_sum = cast(
-            add(*[cast(value, acc_dtype) for value in elements]), out_dtype
-        )
+    element_sum = cast(
+        variadic_add(*[cast(value, acc_dtype) for value in elements]), out_dtype
+    )
 
     return [element_sum]
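
The three branches collapse because variadic_add already covers each case: two or more elements are accumulated in acc_dtype as before, a single element is passed straight to the outer cast, and an empty list yields the plain int 0, which cast turns into a scalar of out_dtype, the role the explicit zeros(dtype=out_dtype, shape=()) branch used to play. A quick check of that empty case (the dtype is chosen for illustration):

    import pytensor.tensor as pt

    out_dtype = "float32"                # illustrative output dtype
    empty_sum = pt.cast(0, out_dtype)    # what cast(variadic_add(), out_dtype) reduces to
    assert empty_sum.dtype == out_dtype  # a symbolic float32 scalar zero, like zeros(shape=())
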

pytensor/tensor/rewriting/blas.py

Lines changed: 10 additions & 5 deletions
@@ -96,7 +96,15 @@
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.math import Dot, _matrix_matrix_matmul, add, mul, neg, sub
+from pytensor.tensor.math import (
+    Dot,
+    _matrix_matrix_matmul,
+    add,
+    mul,
+    neg,
+    sub,
+    variadic_add,
+)
 from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
 from pytensor.tensor.type import (
     DenseTensorType,
@@ -386,10 +394,7 @@ def item_to_var(t):
             item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
         ]
         add_inputs.extend(gemm_of_sM_list)
-        if len(add_inputs) > 1:
-            rval = [add(*add_inputs)]
-        else:
-            rval = add_inputs
+        rval = [variadic_add(*add_inputs)]
         # print "RETURNING GEMM THING", rval
         return rval, old_dot22

pytensor/tensor/rewriting/math.py

Lines changed: 7 additions & 18 deletions
@@ -81,6 +81,8 @@
     sub,
     tri_gamma,
     true_div,
+    variadic_add,
+    variadic_mul,
 )
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import max as pt_max
@@ -1270,17 +1272,13 @@ def local_sum_prod_of_mul_or_div(fgraph, node):
 
         if not outer_terms:
             return None
-        elif len(outer_terms) == 1:
-            [outer_term] = outer_terms
         else:
-            outer_term = mul(*outer_terms)
+            outer_term = variadic_mul(*outer_terms)
 
         if not inner_terms:
             inner_term = None
-        elif len(inner_terms) == 1:
-            [inner_term] = inner_terms
         else:
-            inner_term = mul(*inner_terms)
+            inner_term = variadic_mul(*inner_terms)
 
     else:  # true_div
         # We only care about removing the denominator out of the reduction
@@ -2163,10 +2161,7 @@ def local_add_remove_zeros(fgraph, node):
         assert cst.type.broadcastable == (True,) * ndim
         return [alloc_like(cst, node_output, fgraph)]
 
-    if len(new_inputs) == 1:
-        ret = [alloc_like(new_inputs[0], node_output, fgraph)]
-    else:
-        ret = [alloc_like(add(*new_inputs), node_output, fgraph)]
+    ret = [alloc_like(variadic_add(*new_inputs), node_output, fgraph)]
 
     # The dtype should not be changed. It can happen if the input
     # that was forcing upcasting was equal to 0.
@@ -2277,10 +2272,7 @@ def local_log1p(fgraph, node):
     # scalar_inputs are potentially dimshuffled and fill'd scalars
     if scalars and np.allclose(np.sum(scalars), 1):
         if nonconsts:
-            if len(nonconsts) > 1:
-                ninp = add(*nonconsts)
-            else:
-                ninp = nonconsts[0]
+            ninp = variadic_add(*nonconsts)
             if ninp.dtype != log_arg.type.dtype:
                 ninp = ninp.astype(node.outputs[0].dtype)
             return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]
@@ -3104,10 +3096,7 @@ def local_exp_over_1_plus_exp(fgraph, node):
         return
     # put the new numerator together
     new_num = sigmoids + [exp(t) for t in num_exp_x] + num_rest
-    if len(new_num) == 1:
-        new_num = new_num[0]
-    else:
-        new_num = mul(*new_num)
+    new_num = variadic_mul(*new_num)
 
     if num_neg ^ denom_neg:
         new_num = -new_num
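
For context, local_exp_over_1_plus_exp is the stabilization rewrite that turns exp(x) / (1 + exp(x)) into sigmoid(x); the hunk above only changes how the leftover numerator terms are recombined. A rough way to watch the rewrite fire (the exact graph printout depends on the compilation mode):

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    y = pt.exp(x) / (1 + pt.exp(x))
    f = pytensor.function([x], y)   # default mode runs the stabilization rewrites
    pytensor.dprint(f)              # the optimized graph should show sigmoid(x)
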

pytensor/tensor/rewriting/subtensor.py

Lines changed: 6 additions & 9 deletions
@@ -48,6 +48,7 @@
     maximum,
     minimum,
     or_,
+    variadic_add,
 )
 from pytensor.tensor.math import all as pt_all
 from pytensor.tensor.rewriting.basic import (
@@ -1241,15 +1242,11 @@ def movable(i):
         new_inputs = [i for i in node.inputs if not movable(i)] + [
             mi.owner.inputs[0] for mi in movable_inputs
         ]
-        if len(new_inputs) == 0:
-            new_add = new_inputs[0]
-        else:
-            new_add = add(*new_inputs)
-
-            # Copy over stacktrace from original output, as an error
-            # (e.g. an index error) in this add operation should
-            # correspond to an error in the original add operation.
-            copy_stack_trace(node.outputs[0], new_add)
+        new_add = variadic_add(*new_inputs)
+        # Copy over stacktrace from original output, as an error
+        # (e.g. an index error) in this add operation should
+        # correspond to an error in the original add operation.
+        copy_stack_trace(node.outputs[0], new_add)
 
         # stack up the new incsubtensors
         tip = new_add
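
A detail worth noting: the removed branch tested len(new_inputs) == 0 and then indexed new_inputs[0], which as written could only raise an IndexError; variadic_add instead returns the single remaining input directly when there is only one term. A minimal sketch of that case (assuming a PyTensor build that includes this commit; the vector is illustrative):

    import pytensor.tensor as pt
    from pytensor.tensor.math import variadic_add   # helper added by this commit

    x = pt.vector("x")            # illustrative sole remaining input
    new_add = variadic_add(x)     # returns x itself; no Add node is built
    assert new_add is x
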

0 commit comments