diff --git a/pytensor/ifelse.py b/pytensor/ifelse.py
index b41b5f460d..b7c2c52ee4 100644
--- a/pytensor/ifelse.py
+++ b/pytensor/ifelse.py
@@ -273,7 +273,7 @@ def grad(self, ins, grads):
         # `condition` does affect the elements of the output so it is connected.
         # For the sake of making the gradient convenient we assume that
         # condition + epsilon always triggers the same branch as condition
-        condition_grad = condition.zeros_like().astype(config.floatX)
+        condition_grad = condition.zeros_like(dtype=config.floatX)
 
         return [
             condition_grad,
diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py
index 714c8fd7bf..985cc1bc6a 100644
--- a/pytensor/scalar/basic.py
+++ b/pytensor/scalar/basic.py
@@ -1323,8 +1323,8 @@ def L_op(self, inputs, outputs, output_gradients):
         x, y = inputs
         assert outputs[0].type == bool
         return [
-            x.zeros_like().astype(config.floatX),
-            y.zeros_like().astype(config.floatX),
+            x.zeros_like(dtype=config.floatX),
+            y.zeros_like(dtype=config.floatX),
         ]
 
     def c_code_cache_version(self):
@@ -1358,7 +1358,7 @@ def output_types(self, *input_dtypes):
     def L_op(self, inputs, outputs, output_gradients):
         (x,) = inputs
         assert outputs[0].type == bool
-        return [x.zeros_like().astype(config.floatX)]
+        return [x.zeros_like(dtype=config.floatX)]
 
     def c_code_cache_version(self):
         super_version = super().c_code_cache_version()
@@ -1577,7 +1577,7 @@ def get_grad(self, elem):
             )
             raise NotImplementedError(msg)
         elif elem.type in discrete_types:
-            return elem.zeros_like().astype(config.floatX)
+            return elem.zeros_like(dtype=config.floatX)
         else:
             return elem.zeros_like()
 
@@ -1611,13 +1611,13 @@ def L_op(self, inputs, outputs, gout):
         second_part = switch(cond, 0.0, gz)
 
         if outputs[0].type in discrete_types:
-            first_part = ift.zeros_like(config.floatX)
-            second_part = iff.zeros_like(config.floatX)
+            first_part = ift.zeros_like(dtype=config.floatX)
+            second_part = iff.zeros_like(dtype=config.floatX)
 
         # cond does affect the elements of the output so it is connected.
         # For the sake of making the gradient convenient we assume that
         # condition + epsilon always triggers the same branch as condition
-        condition_grad = cond.zeros_like().astype(config.floatX)
+        condition_grad = cond.zeros_like(dtype=config.floatX)
 
         return (condition_grad, first_part, second_part)
 
@@ -1644,7 +1644,7 @@ def output_types(self, *input_types):
         return upcast_out(*input_types[0])
 
     def grad(self, inputs, output_gradients):
-        return [inputs[0].zeros_like().astype(config.floatX)]
+        return [inputs[0].zeros_like(dtype=config.floatX)]
 
 
 class BinaryBitOp(BinaryScalarOp):
@@ -1664,8 +1664,8 @@ def output_types(self, *input_types):
     def grad(self, inputs, output_gradients):
         a, b = inputs
         return [
-            a.zeros_like().astype(config.floatX),
-            b.zeros_like().astype(config.floatX),
+            a.zeros_like(dtype=config.floatX),
+            b.zeros_like(dtype=config.floatX),
         ]
 
 
@@ -1776,8 +1776,8 @@ def L_op(self, inputs, outputs, gout):
 
         if outputs[0].type in discrete_types:
             return [
-                x.zeros_like().astype(config.floatX),
-                y.zeros_like().astype(config.floatX),
+                x.zeros_like(dtype=config.floatX),
+                y.zeros_like(dtype=config.floatX),
             ]
         # This form handle the case when both value are the same.
         # In that case, gx will be gz, gy will be 0.
@@ -1818,8 +1818,8 @@ def L_op(self, inputs, outputs, gout):
 
         if outputs[0].type in discrete_types:
             return [
-                x.zeros_like().astype(config.floatX),
-                y.zeros_like().astype(config.floatX),
+                x.zeros_like(dtype=config.floatX),
+                y.zeros_like(dtype=config.floatX),
             ]
         # This form handle the case when both value are the same.
         # In that case, gx will be gz, gy will be 0.
@@ -1861,7 +1861,7 @@ def L_op(self, inputs, outputs, gout):
             retval = []
             for ii, inp in enumerate(inputs):
                 if hasattr(inp, "zeros_like"):
-                    retval.append(inp.zeros_like().astype(config.floatX))
+                    retval.append(inp.zeros_like(dtype=config.floatX))
                 else:
                     retval.append(grad_undefined(self, ii, inp))
         else:
@@ -1937,7 +1937,7 @@ def grad(self, inputs, gout):
             )
 
         if output_type in discrete_types:
-            return [ipt.zeros_like().astype(config.floatX) for ipt in inputs]
+            return [ipt.zeros_like(dtype=config.floatX) for ipt in inputs]
 
         for input in inputs:
             if gz.type in complex_types:
@@ -1980,8 +1980,8 @@ def L_op(self, inputs, outputs, gout):
             raise NotImplementedError()
         if outputs[0].type in discrete_types:
             return [
-                x.zeros_like().astype(config.floatX),
-                y.zeros_like().astype(config.floatX),
+                x.zeros_like(dtype=config.floatX),
+                y.zeros_like(dtype=config.floatX),
             ]
 
         first_part = gz
@@ -2036,7 +2036,10 @@ def grad(self, inputs, gout):
         # to the output; x/y is still a function of x
         # and y; it's just a step function.
         if all(a.dtype in discrete_dtypes for a in (x, y)):
-            return [x.zeros_like(), y.zeros_like()]
+            return [
+                x.zeros_like(dtype=config.floatX),
+                y.zeros_like(dtype=config.floatX),
+            ]
 
         first_part = gz / y
 
@@ -2293,8 +2296,8 @@ def L_op(self, inputs, outputs, gout):
 
         if outputs[0].type in discrete_types:
             return [
-                x.zeros_like().astype(config.floatX),
-                y.zeros_like().astype(config.floatX),
+                x.zeros_like(dtype=config.floatX),
+                y.zeros_like(dtype=config.floatX),
             ]
 
         first_part = gz * y * x ** (y - 1)
@@ -2385,7 +2388,7 @@ def L_op(self, inputs, outputs, gout):
 
         def handle_int(v):
             if outputs[0].type in int_types:
-                return v.zeros_like().astype(config.floatX)
+                return v.zeros_like(dtype=config.floatX)
             return v
 
         return list(map(handle_int, [gx, gmn, gmx]))
@@ -2422,7 +2425,7 @@ def grad(self, inputs, gout):
         # to deal with real-valued inputs by rounding them to the
        # nearest integer. f(x+eps) thus equals f(x) so the gradient
         # is zero, not disconnected or undefined
-        return DisconnectedType()(), y.zeros_like()
+        return DisconnectedType()(), y.zeros_like(dtype=config.floatX)
 
 
 second = Second(transfer_type(1), name="second")
@@ -2494,7 +2497,7 @@ def grad(self, inputs, gout):
         if self.o_type in continuous_types:
             return [gz]
         else:
-            return [x.zeros_like().astype(config.floatX)]
+            return [x.zeros_like(dtype=config.floatX)]
 
     def c_code_cache_version(self):
         s = super().c_code_cache_version()
@@ -2715,7 +2718,7 @@ def impl(self, x):
     def grad(self, inputs, gout):
         (x,) = inputs
         (gz,) = gout
-        return [x.zeros_like().astype(config.floatX)]
+        return [x.zeros_like(dtype=config.floatX)]
 
     def c_code(self, node, name, inputs, outputs, sub):
         (x,) = inputs
diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 9295a130c2..d3ac61ae02 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -589,7 +589,7 @@ def grad(self, inp, grads):
         # Currently, pytensor.grad insists that the dtype of the returned
         # gradient has a float dtype, so we use floatX.
         if s.type.dtype in discrete_dtypes:
-            return [s.zeros_like().astype(config.floatX)]
+            return [s.zeros_like(dtype=config.floatX)]
 
         raise NotImplementedError("grad not implemented for complex dtypes")
 
@@ -1876,7 +1876,7 @@ def infer_shape(self, fgraph, node, ishapes):
     def grad(self, inputs, output_gradients):
         # If the output is of an integer dtype, no gradient shall pass
         if self.dtype in discrete_dtypes:
-            return [ipt.zeros_like().astype(config.floatX) for ipt in inputs]
+            return [ipt.zeros_like(dtype=config.floatX) for ipt in inputs]
 
         grads = [output_gradients[0][i] for i in range(len(inputs))]
         return grads
diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py
index 22a08718ae..b3cf96cbd4 100644
--- a/pytensor/tensor/blas.py
+++ b/pytensor/tensor/blas.py
@@ -102,7 +102,7 @@
 from pytensor.tensor.basic import expand_dims
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
 from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import add, mul, neg, sub
+from pytensor.tensor.math import add, mul, neg, sub, variadic_add
 from pytensor.tensor.shape import shape_padright, specify_broadcastable
 from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor
 
@@ -1399,11 +1399,7 @@ def item_to_var(t):
                     item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
                 ]
                 add_inputs.extend(gemm_of_sM_list)
-                if len(add_inputs) > 1:
-                    rval = [add(*add_inputs)]
-                else:
-                    rval = add_inputs
-                # print "RETURNING GEMM THING", rval
+                rval = [variadic_add(*add_inputs)]
                 return rval, old_dot22
 
 
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index 1ad9ce0158..7defb355ef 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -1430,18 +1430,12 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None)
     else:
         shp = cast(shp, "float64")
 
-    if axis is None:
-        axis = list(range(input.ndim))
-    elif isinstance(axis, int | np.integer):
-        axis = [axis]
-    elif isinstance(axis, np.ndarray) and axis.ndim == 0:
-        axis = [int(axis)]
-    else:
-        axis = [int(a) for a in axis]
-
-    # This sequential division will possibly be optimized by PyTensor:
-    for i in axis:
-        s = true_div(s, shp[i])
+    reduced_dims = (
+        shp
+        if axis is None
+        else [shp[i] for i in normalize_axis_tuple(axis, input.type.ndim)]
+    )
+    s /= variadic_mul(*reduced_dims).astype(shp.dtype)
 
     # This can happen when axis is an empty list/tuple
     if s.dtype != shp.dtype and s.dtype in discrete_dtypes:
@@ -1597,6 +1591,15 @@ def add(a, *other_terms):
     # see decorator for function body
 
 
+def variadic_add(*args):
+    """Add that accepts arbitrary number of inputs, including zero or one."""
+    if not args:
+        return constant(0)
+    if len(args) == 1:
+        return args[0]
+    return add(*args)
+
+
 @scalar_elemwise
 def sub(a, b):
     """elementwise subtraction"""
@@ -1609,6 +1612,15 @@ def mul(a, *other_terms):
     # see decorator for function body
 
 
+def variadic_mul(*args):
+    """Mul that accepts arbitrary number of inputs, including zero or one."""
+    if not args:
+        return constant(1)
+    if len(args) == 1:
+        return args[0]
+    return mul(*args)
+
+
 @scalar_elemwise
 def true_div(a, b):
     """elementwise [true] division (inverse of multiplication)"""
diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 6a038cab15..23618f536e 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -68,7 +68,7 @@
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_arrays
-from pytensor.tensor.math import Sum, add, eq
+from pytensor.tensor.math import Sum, add, eq, variadic_add
 from pytensor.tensor.shape import Shape_i, shape_padleft
 from pytensor.tensor.type import DenseTensorType, TensorType
 from pytensor.tensor.variable import TensorConstant, TensorVariable
@@ -939,14 +939,9 @@ def local_sum_make_vector(fgraph, node):
     if acc_dtype == "float64" and out_dtype != "float64" and config.floatX != "float64":
         return
 
-    if len(elements) == 0:
-        element_sum = zeros(dtype=out_dtype, shape=())
-    elif len(elements) == 1:
-        element_sum = cast(elements[0], out_dtype)
-    else:
-        element_sum = cast(
-            add(*[cast(value, acc_dtype) for value in elements]), out_dtype
-        )
+    element_sum = cast(
+        variadic_add(*[cast(value, acc_dtype) for value in elements]), out_dtype
+    )
 
     return [element_sum]
 
diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py
index cc8dd472e6..d52ee70e17 100644
--- a/pytensor/tensor/rewriting/blas.py
+++ b/pytensor/tensor/rewriting/blas.py
@@ -96,7 +96,15 @@
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.math import Dot, _matrix_matrix_matmul, add, mul, neg, sub
+from pytensor.tensor.math import (
+    Dot,
+    _matrix_matrix_matmul,
+    add,
+    mul,
+    neg,
+    sub,
+    variadic_add,
+)
 from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
 from pytensor.tensor.type import (
     DenseTensorType,
@@ -386,10 +394,7 @@ def item_to_var(t):
                     item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
                 ]
                 add_inputs.extend(gemm_of_sM_list)
-                if len(add_inputs) > 1:
-                    rval = [add(*add_inputs)]
-                else:
-                    rval = add_inputs
+                rval = [variadic_add(*add_inputs)]
                 # print "RETURNING GEMM THING", rval
                 return rval, old_dot22
 
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 75dba82d97..ccad9aa517 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -81,6 +81,8 @@
     sub,
     tri_gamma,
     true_div,
+    variadic_add,
+    variadic_mul,
 )
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import max as pt_max
@@ -1270,17 +1272,13 @@ def local_sum_prod_of_mul_or_div(fgraph, node):
 
         if not outer_terms:
             return None
-        elif len(outer_terms) == 1:
-            [outer_term] = outer_terms
         else:
-            outer_term = mul(*outer_terms)
+            outer_term = variadic_mul(*outer_terms)
 
         if not inner_terms:
             inner_term = None
-        elif len(inner_terms) == 1:
-            [inner_term] = inner_terms
         else:
-            inner_term = mul(*inner_terms)
+            inner_term = variadic_mul(*inner_terms)
 
     else:  # true_div
         # We only care about removing the denominator out of the reduction
@@ -2163,10 +2161,7 @@ def local_add_remove_zeros(fgraph, node):
         assert cst.type.broadcastable == (True,) * ndim
         return [alloc_like(cst, node_output, fgraph)]
 
-    if len(new_inputs) == 1:
-        ret = [alloc_like(new_inputs[0], node_output, fgraph)]
-    else:
-        ret = [alloc_like(add(*new_inputs), node_output, fgraph)]
+    ret = [alloc_like(variadic_add(*new_inputs), node_output, fgraph)]
 
     # The dtype should not be changed. It can happen if the input
     # that was forcing upcasting was equal to 0.
@@ -2277,10 +2272,7 @@ def local_log1p(fgraph, node):
         # scalar_inputs are potentially dimshuffled and fill'd scalars
         if scalars and np.allclose(np.sum(scalars), 1):
             if nonconsts:
-                if len(nonconsts) > 1:
-                    ninp = add(*nonconsts)
-                else:
-                    ninp = nonconsts[0]
+                ninp = variadic_add(*nonconsts)
                 if ninp.dtype != log_arg.type.dtype:
                     ninp = ninp.astype(node.outputs[0].dtype)
                 return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]
@@ -3104,10 +3096,7 @@ def local_exp_over_1_plus_exp(fgraph, node):
         return
     # put the new numerator together
     new_num = sigmoids + [exp(t) for t in num_exp_x] + num_rest
-    if len(new_num) == 1:
-        new_num = new_num[0]
-    else:
-        new_num = mul(*new_num)
+    new_num = variadic_mul(*new_num)
 
     if num_neg ^ denom_neg:
         new_num = -new_num
diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index f234b46804..0e7f9cc3f1 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -48,6 +48,7 @@
     maximum,
     minimum,
     or_,
+    variadic_add,
 )
 from pytensor.tensor.math import all as pt_all
 from pytensor.tensor.rewriting.basic import (
@@ -1241,15 +1242,11 @@ def movable(i):
             new_inputs = [i for i in node.inputs if not movable(i)] + [
                 mi.owner.inputs[0] for mi in movable_inputs
             ]
-            if len(new_inputs) == 0:
-                new_add = new_inputs[0]
-            else:
-                new_add = add(*new_inputs)
-
-            # Copy over stacktrace from original output, as an error
-            # (e.g. an index error) in this add operation should
-            # correspond to an error in the original add operation.
-            copy_stack_trace(node.outputs[0], new_add)
+            new_add = variadic_add(*new_inputs)
+            # Copy over stacktrace from original output, as an error
+            # (e.g. an index error) in this add operation should
+            # correspond to an error in the original add operation.
+            copy_stack_trace(node.outputs[0], new_add)
 
             # stack up the new incsubtensors
             tip = new_add
diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index e40c308f35..aa47a3415c 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -946,7 +946,7 @@ def grad(self, inputs, grads):
         x = inputs[0]
         rest = inputs[1:]
         if x.dtype in discrete_dtypes:
-            first = x.zeros_like().astype(config.floatX)
+            first = x.zeros_like(dtype=config.floatX)
         else:
             # For best optimization, we let this as an inc.
             # This allow the opt local_IncSubtensor_serialize to apply first.
diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py
index 54f93570d4..d793834817 100644
--- a/tests/tensor/test_math.py
+++ b/tests/tensor/test_math.py
@@ -3210,52 +3210,56 @@ def test_mean_default_dtype(self):
             # TODO FIXME: This is a bad test
             f(data)
 
-    @pytest.mark.slow
-    def test_mean_custom_dtype(self):
+    @pytest.mark.parametrize(
+        "input_dtype",
+        (
+            "bool",
+            "uint16",
+            "int8",
+            "int64",
+            "float16",
+            "float32",
+            "float64",
+            "complex64",
+            "complex128",
+        ),
+    )
+    @pytest.mark.parametrize(
+        "sum_dtype",
+        (
+            "bool",
+            "uint16",
+            "int8",
+            "int64",
+            "float16",
+            "float32",
+            "float64",
+            "complex64",
+            "complex128",
+        ),
+    )
+    @pytest.mark.parametrize("axis", [None, ()])
+    def test_mean_custom_dtype(self, input_dtype, sum_dtype, axis):
         # Test the ability to provide your own output dtype for a mean.
 
-        # We try multiple axis combinations even though axis should not matter.
-        axes = [None, 0, 1, [], [0], [1], [0, 1]]
-        idx = 0
-        for input_dtype in map(str, ps.all_types):
-            x = matrix(dtype=input_dtype)
-            for sum_dtype in map(str, ps.all_types):
-                axis = axes[idx % len(axes)]
-                # If the inner sum cannot be created, it will raise a
-                # TypeError.
-                try:
-                    mean_var = x.mean(dtype=sum_dtype, axis=axis)
-                except TypeError:
-                    pass
-                else:
-                    # Executed if no TypeError was raised
-                    if sum_dtype in discrete_dtypes:
-                        assert mean_var.dtype == "float64", (mean_var.dtype, sum_dtype)
-                    else:
-                        assert mean_var.dtype == sum_dtype, (mean_var.dtype, sum_dtype)
-                    if (
-                        "complex" in input_dtype or "complex" in sum_dtype
-                    ) and input_dtype != sum_dtype:
-                        continue
-                    f = function([x], mean_var)
-                    data = np.random.random((3, 4)) * 10
-                    data = data.astype(input_dtype)
-                    # TODO FIXME: This is a bad test
-                    f(data)
-                    # Check that we can take the gradient, when implemented
-                    if "complex" in mean_var.dtype:
-                        continue
-                    try:
-                        grad(mean_var.sum(), x, disconnected_inputs="ignore")
-                    except NotImplementedError:
-                        # TrueDiv does not seem to have a gradient when
-                        # the numerator is complex.
-                        if mean_var.dtype in complex_dtypes:
-                            pass
-                        else:
-                            raise
+        x = matrix(dtype=input_dtype)
+        # If the inner sum cannot be created, it will raise a TypeError.
+        mean_var = x.mean(dtype=sum_dtype, axis=axis)
+        if sum_dtype in discrete_dtypes:
+            assert mean_var.dtype == "float64", (mean_var.dtype, sum_dtype)
+        else:
+            assert mean_var.dtype == sum_dtype, (mean_var.dtype, sum_dtype)
 
-                idx += 1
+        f = function([x], mean_var, mode="FAST_COMPILE")
+        data = np.ones((2, 1)).astype(input_dtype)
+        if axis != ():
+            expected_res = np.array(2).astype(sum_dtype) / 2
+        else:
+            expected_res = data
+        np.testing.assert_allclose(f(data), expected_res)
+
+        if "complex" not in mean_var.dtype:
+            grad(mean_var.sum(), x, disconnected_inputs="ignore")
 
     def test_mean_precision(self):
         # Check that the default accumulator precision is sufficient
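
Note (illustrative sketch, not part of the patch): the new variadic_add / variadic_mul helpers absorb the repeated "zero / one / many operands" branching that this diff deletes from the rewrites, and x.zeros_like(dtype=config.floatX) builds the zero gradient directly in the target dtype instead of going through .astype(). Assuming the patch is applied, the intended behaviour of the helpers is:

    # Minimal sketch of the helpers introduced in pytensor/tensor/math.py above.
    import pytensor.tensor as pt
    from pytensor.tensor.math import variadic_add, variadic_mul

    x = pt.scalar("x")
    y = pt.scalar("y")

    variadic_add(x, y)  # builds the same graph as add(x, y)
    variadic_add(x)     # returns x unchanged; no Add node is created
    variadic_add()      # returns constant(0)
    variadic_mul()      # returns constant(1)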