2 changes: 1 addition & 1 deletion pytensor/ifelse.py
@@ -273,7 +273,7 @@ def grad(self, ins, grads):
# `condition` does affect the elements of the output so it is connected.
# For the sake of making the gradient convenient we assume that
# condition + epsilon always triggers the same branch as condition
- condition_grad = condition.zeros_like().astype(config.floatX)
+ condition_grad = condition.zeros_like(dtype=config.floatX)

return [
condition_grad,
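Note (not part of the diff): the recurring change in this PR swaps the allocate-then-cast spelling for a direct dtype argument. A minimal sketch of the before/after, assuming pytensor is installed and that `TensorVariable.zeros_like` accepts a `dtype` keyword, as the new call sites rely on:

```python
# Hedged sketch, illustrative only: the old and new spellings build a zero
# gradient with the same final dtype; the new one skips the intermediate cast.
import pytensor.tensor as pt
from pytensor import config

x = pt.ivector("x")  # an integer input, as in the discrete-dtype branches

old = x.zeros_like().astype(config.floatX)  # allocate int32 zeros, then cast
new = x.zeros_like(dtype=config.floatX)     # allocate directly in floatX

assert old.dtype == new.dtype == config.floatX
```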
53 changes: 28 additions & 25 deletions pytensor/scalar/basic.py
@@ -1323,8 +1323,8 @@
x, y = inputs
assert outputs[0].type == bool
return [
- x.zeros_like().astype(config.floatX),
- y.zeros_like().astype(config.floatX),
+ x.zeros_like(dtype=config.floatX),
+ y.zeros_like(dtype=config.floatX),
]

def c_code_cache_version(self):
@@ -1358,7 +1358,7 @@
def L_op(self, inputs, outputs, output_gradients):
(x,) = inputs
assert outputs[0].type == bool
- return [x.zeros_like().astype(config.floatX)]
+ return [x.zeros_like(dtype=config.floatX)]

def c_code_cache_version(self):
super_version = super().c_code_cache_version()
@@ -1577,7 +1577,7 @@
)
raise NotImplementedError(msg)
elif elem.type in discrete_types:
- return elem.zeros_like().astype(config.floatX)
+ return elem.zeros_like(dtype=config.floatX)

else:
return elem.zeros_like()

@@ -1611,13 +1611,13 @@
second_part = switch(cond, 0.0, gz)

if outputs[0].type in discrete_types:
- first_part = ift.zeros_like(config.floatX)
- second_part = iff.zeros_like(config.floatX)
+ first_part = ift.zeros_like(dtype=config.floatX)
+ second_part = iff.zeros_like(dtype=config.floatX)

# cond does affect the elements of the output so it is connected.
# For the sake of making the gradient convenient we assume that
# condition + epsilon always triggers the same branch as condition
- condition_grad = cond.zeros_like().astype(config.floatX)
+ condition_grad = cond.zeros_like(dtype=config.floatX)

return (condition_grad, first_part, second_part)

@@ -1644,7 +1644,7 @@
return upcast_out(*input_types[0])

def grad(self, inputs, output_gradients):
- return [inputs[0].zeros_like().astype(config.floatX)]
+ return [inputs[0].zeros_like(dtype=config.floatX)]


class BinaryBitOp(BinaryScalarOp):
@@ -1664,8 +1664,8 @@
def grad(self, inputs, output_gradients):
a, b = inputs
return [
- a.zeros_like().astype(config.floatX),
- b.zeros_like().astype(config.floatX),
+ a.zeros_like(dtype=config.floatX),
+ b.zeros_like(dtype=config.floatX),
]


@@ -1776,8 +1776,8 @@

if outputs[0].type in discrete_types:
return [
- x.zeros_like().astype(config.floatX),
- y.zeros_like().astype(config.floatX),
+ x.zeros_like(dtype=config.floatX),
+ y.zeros_like(dtype=config.floatX),
]
# This form handle the case when both value are the same.
# In that case, gx will be gz, gy will be 0.
@@ -1818,8 +1818,8 @@

if outputs[0].type in discrete_types:
return [
- x.zeros_like().astype(config.floatX),
- y.zeros_like().astype(config.floatX),
+ x.zeros_like(dtype=config.floatX),
+ y.zeros_like(dtype=config.floatX),
]
# This form handle the case when both value are the same.
# In that case, gx will be gz, gy will be 0.
@@ -1861,7 +1861,7 @@
retval = []
for ii, inp in enumerate(inputs):
if hasattr(inp, "zeros_like"):
- retval.append(inp.zeros_like().astype(config.floatX))
+ retval.append(inp.zeros_like(dtype=config.floatX))
else:
retval.append(grad_undefined(self, ii, inp))
else:
@@ -1937,7 +1937,7 @@
)

if output_type in discrete_types:
- return [ipt.zeros_like().astype(config.floatX) for ipt in inputs]
+ return [ipt.zeros_like(dtype=config.floatX) for ipt in inputs]

for input in inputs:
if gz.type in complex_types:
@@ -1980,8 +1980,8 @@
raise NotImplementedError()
if outputs[0].type in discrete_types:
return [
- x.zeros_like().astype(config.floatX),
- y.zeros_like().astype(config.floatX),
+ x.zeros_like(dtype=config.floatX),
+ y.zeros_like(dtype=config.floatX),
]

first_part = gz
@@ -2036,7 +2036,10 @@
# to the output; x/y is still a function of x
# and y; it's just a step function.
if all(a.dtype in discrete_dtypes for a in (x, y)):
- return [x.zeros_like(), y.zeros_like()]
+ return [
+     x.zeros_like(dtype=config.floatX),
+     y.zeros_like(dtype=config.floatX),
+ ]

first_part = gz / y

@@ -2293,8 +2296,8 @@

if outputs[0].type in discrete_types:
return [
- x.zeros_like().astype(config.floatX),
- y.zeros_like().astype(config.floatX),
+ x.zeros_like(dtype=config.floatX),
+ y.zeros_like(dtype=config.floatX),
]

first_part = gz * y * x ** (y - 1)
@@ -2385,7 +2388,7 @@

def handle_int(v):
if outputs[0].type in int_types:
- return v.zeros_like().astype(config.floatX)
+ return v.zeros_like(dtype=config.floatX)
return v

return list(map(handle_int, [gx, gmn, gmx]))
@@ -2422,7 +2425,7 @@
# to deal with real-valued inputs by rounding them to the
# nearest integer. f(x+eps) thus equals f(x) so the gradient
# is zero, not disconnected or undefined
- return DisconnectedType()(), y.zeros_like()
+ return DisconnectedType()(), y.zeros_like(dtype=config.floatX)


second = Second(transfer_type(1), name="second")
@@ -2494,7 +2497,7 @@
if self.o_type in continuous_types:
return [gz]
else:
- return [x.zeros_like().astype(config.floatX)]
+ return [x.zeros_like(dtype=config.floatX)]

def c_code_cache_version(self):
s = super().c_code_cache_version()
@@ -2715,7 +2718,7 @@
def grad(self, inputs, gout):
(x,) = inputs
(gz,) = gout
- return [x.zeros_like().astype(config.floatX)]
+ return [x.zeros_like(dtype=config.floatX)]

def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
4 changes: 2 additions & 2 deletions pytensor/tensor/basic.py
@@ -589,7 +589,7 @@ def grad(self, inp, grads):
# Currently, pytensor.grad insists that the dtype of the returned
# gradient has a float dtype, so we use floatX.
if s.type.dtype in discrete_dtypes:
- return [s.zeros_like().astype(config.floatX)]
+ return [s.zeros_like(dtype=config.floatX)]

raise NotImplementedError("grad not implemented for complex dtypes")

@@ -1876,7 +1876,7 @@ def infer_shape(self, fgraph, node, ishapes):
def grad(self, inputs, output_gradients):
# If the output is of an integer dtype, no gradient shall pass
if self.dtype in discrete_dtypes:
- return [ipt.zeros_like().astype(config.floatX) for ipt in inputs]
+ return [ipt.zeros_like(dtype=config.floatX) for ipt in inputs]

grads = [output_gradients[0][i] for i in range(len(inputs))]
return grads
8 changes: 2 additions & 6 deletions pytensor/tensor/blas.py
@@ -102,7 +102,7 @@
from pytensor.tensor.basic import expand_dims
from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
from pytensor.tensor.elemwise import DimShuffle
- from pytensor.tensor.math import add, mul, neg, sub
+ from pytensor.tensor.math import add, mul, neg, sub, variadic_add
from pytensor.tensor.shape import shape_padright, specify_broadcastable
from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor

@@ -1399,11 +1399,7 @@
item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
]
add_inputs.extend(gemm_of_sM_list)
- if len(add_inputs) > 1:
-     rval = [add(*add_inputs)]
- else:
-     rval = add_inputs
- # print "RETURNING GEMM THING", rval
+ rval = [variadic_add(*add_inputs)]

return rval, old_dot22


36 changes: 24 additions & 12 deletions pytensor/tensor/math.py
@@ -1430,18 +1430,12 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None)
else:
shp = cast(shp, "float64")

- if axis is None:
-     axis = list(range(input.ndim))
- elif isinstance(axis, int | np.integer):
-     axis = [axis]
- elif isinstance(axis, np.ndarray) and axis.ndim == 0:
-     axis = [int(axis)]
- else:
-     axis = [int(a) for a in axis]
-
- # This sequential division will possibly be optimized by PyTensor:
- for i in axis:
-     s = true_div(s, shp[i])
+ reduced_dims = (
+     shp
+     if axis is None
+     else [shp[i] for i in normalize_axis_tuple(axis, input.type.ndim)]
+ )
+ s /= variadic_mul(*reduced_dims).astype(shp.dtype)

# This can happen when axis is an empty list/tuple
if s.dtype != shp.dtype and s.dtype in discrete_dtypes:
@@ -1597,6 +1591,15 @@ def add(a, *other_terms):
# see decorator for function body


+ def variadic_add(*args):
+     """Add that accepts arbitrary number of inputs, including zero or one."""
+     if not args:
+         return constant(0)
+     if len(args) == 1:
+         return args[0]
+     return add(*args)


@scalar_elemwise
def sub(a, b):
"""elementwise subtraction"""
@@ -1609,6 +1612,15 @@ def mul(a, *other_terms):
# see decorator for function body


+ def variadic_mul(*args):
+     """Mul that accepts arbitrary number of inputs, including zero or one."""
+     if not args:
+         return constant(1)
+     if len(args) == 1:
+         return args[0]
+     return mul(*args)


@scalar_elemwise
def true_div(a, b):
"""elementwise [true] division (inverse of multiplication)"""
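Note (not part of the diff): the two helpers added above are what the later call sites lean on. A hedged usage sketch, assuming a pytensor build where `variadic_add`/`variadic_mul` are importable from `pytensor.tensor.math`:

```python
# Illustrative sketch of the new helpers' edge cases.
import pytensor.tensor as pt
from pytensor.tensor.math import variadic_add, variadic_mul

x = pt.scalar("x")
y = pt.scalar("y")

# Two or more inputs: defer to the ordinary add/mul Ops.
s = variadic_add(x, y)       # same graph as x + y
p = variadic_mul(x, y, x)    # same graph as x * y * x

# One input: returned unchanged. Zero inputs: the identity element as a constant.
assert variadic_add(x) is x
print(variadic_add().eval())  # 0
print(variadic_mul().eval())  # 1
```

Folding the zero- and one-input cases into the helpers is what lets the `local_sum_make_vector` and gemm rewrites below drop their `len(...)` special cases.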
13 changes: 4 additions & 9 deletions pytensor/tensor/rewriting/basic.py
@@ -68,7 +68,7 @@
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
- from pytensor.tensor.math import Sum, add, eq
+ from pytensor.tensor.math import Sum, add, eq, variadic_add
from pytensor.tensor.shape import Shape_i, shape_padleft
from pytensor.tensor.type import DenseTensorType, TensorType
from pytensor.tensor.variable import TensorConstant, TensorVariable
@@ -939,14 +939,9 @@ def local_sum_make_vector(fgraph, node):
if acc_dtype == "float64" and out_dtype != "float64" and config.floatX != "float64":
return

- if len(elements) == 0:
-     element_sum = zeros(dtype=out_dtype, shape=())
- elif len(elements) == 1:
-     element_sum = cast(elements[0], out_dtype)
- else:
-     element_sum = cast(
-         add(*[cast(value, acc_dtype) for value in elements]), out_dtype
-     )
+ element_sum = cast(
+     variadic_add(*[cast(value, acc_dtype) for value in elements]), out_dtype
+ )

return [element_sum]

15 changes: 10 additions & 5 deletions pytensor/tensor/rewriting/blas.py
@@ -96,7 +96,15 @@
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
- from pytensor.tensor.math import Dot, _matrix_matrix_matmul, add, mul, neg, sub
+ from pytensor.tensor.math import (
+     Dot,
+     _matrix_matrix_matmul,
+     add,
+     mul,
+     neg,
+     sub,
+     variadic_add,
+ )
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
from pytensor.tensor.type import (
DenseTensorType,
@@ -386,10 +394,7 @@ def item_to_var(t):
item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
]
add_inputs.extend(gemm_of_sM_list)
- if len(add_inputs) > 1:
-     rval = [add(*add_inputs)]
- else:
-     rval = add_inputs
+ rval = [variadic_add(*add_inputs)]
# print "RETURNING GEMM THING", rval
return rval, old_dot22
