diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index 2956afad02..b0546b42a4 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -203,6 +203,6 @@ def tri(*args): x if const_x is None else const_x for x, const_x in zip(args, const_args, strict=True) ] - return jnp.tri(*args, dtype=op.dtype) + return jnp.tri(*args) return tri diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index e30887cfe3..dc85fbd3b6 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1060,6 +1060,65 @@ def flatnonzero(a): return nonzero(_a.flatten(), return_matrix=False)[0] +def iota(shape: TensorVariable, axis: int) -> TensorVariable: + """ + Create an array with values increasing along the specified axis. + + Iota is a multidimensional generalization of the `arange` function. The returned array is filled with whole numbers + increasing along the specified axis. + + Parameters + ---------- + shape: TensorVariable + The shape of the array to be created. + axis: int + The axis along which to fill the array with increasing values. + + Returns + ------- + TensorVariable + An array with values increasing along the specified axis. + + Examples + -------- + In the simplest case where ``shape`` is 1d, the output will be equivalent to ``pt.arange``: + + .. testcode:: + + import pytensor.tensor as pt + + shape = pt.as_tensor((5,)) + print(pt.basic.iota(shape, 0).eval()) + + .. testoutput:: + + [0 1 2 3 4] + + In higher dimensions, it will look like many concatenated `arange`: + + .. testcode:: + + shape = pt.as_tensor((5, 5)) + print(pt.basic.iota(shape, 1).eval()) + + .. testoutput:: + + [[0 1 2 3 4] + [0 1 2 3 4] + [0 1 2 3 4] + [0 1 2 3 4] + [0 1 2 3 4]] + + Setting ``axis=0`` above would result in the transpose of the output. + """ + len_shape = get_vector_length(shape) + axis = normalize_axis_index(axis, len_shape) + values = arange(shape[axis]) + return pytensor.tensor.extra_ops.broadcast_to( + shape_padright(values, len_shape - axis - 1), shape + ) + + def nonzero_values(a): """Return a vector of non-zero elements contained in the input array. @@ -1084,35 +1143,10 @@ def nonzero_values(a): return _a.flatten()[flatnonzero(_a)] -class Tri(Op): - __props__ = ("dtype",) - - def __init__(self, dtype=None): - if dtype is None: - dtype = config.floatX - self.dtype = dtype - - def make_node(self, N, M, k): - N = as_tensor_variable(N) - M = as_tensor_variable(M) - k = as_tensor_variable(k) - return Apply( - self, - [N, M, k], - [TensorType(dtype=self.dtype, shape=(None, None))()], - ) - - def perform(self, node, inp, out_): - N, M, k = inp - (out,) = out_ - out[0] = np.tri(N, M, k, dtype=self.dtype) - - def infer_shape(self, fgraph, node, in_shapes): - out_shape = [node.inputs[0], node.inputs[1]] - return [out_shape] - - def grad(self, inp, grads): - return [grad_undefined(self, i, inp[i]) for i in range(3)] +class Tri(OpFromGraph): + """ + Wrapper Op for np.tri graphs + """ def tri(N, M=None, k=0, dtype=None): @@ -1142,10 +1176,19 @@ def tri(N, M=None, k=0, dtype=None): """ if dtype is None: dtype = config.floatX + dtype = np.dtype(dtype) + if M is None: M = N - op = Tri(dtype) - return op(N, M, k) + + N = as_tensor_variable(N) + M = as_tensor_variable(M) + k = as_tensor_variable(k) + + output = ((iota(as_tensor((N, 1)), 0) + k + 1) > iota(as_tensor((1, M)), 1)).astype( + dtype + ) + return Tri(inputs=[N, M, k], outputs=[output])(N, M, k) def tril(m, k=0): @@ -1225,7 +1268,7 @@ def triu(m, k=0): [ 0, 8, 9], [ 0, 0, 12]]) - >>> pt.triu(np.arange(3 * 4 * 5).reshape((3, 4, 5))).eval() + >>> pt.triu(pt.arange(3 * 4 * 5).reshape((3, 4, 5))).eval() array([[[ 0, 1, 2, 3, 4], [ 0, 6, 7, 8, 9], [ 0, 0, 12, 13, 14], diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py index 660c16d387..145a3c1bfd 100644 --- a/pytensor/tensor/einsum.py +++ b/pytensor/tensor/einsum.py @@ -11,15 +11,13 @@ from pytensor.npy_2_compat import ( _find_contraction, _parse_einsum_input, - normalize_axis_index, normalize_axis_tuple, ) from pytensor.tensor import TensorLike from pytensor.tensor.basic import ( - arange, as_tensor, expand_dims, - get_vector_length, + iota, moveaxis, stack, transpose, @@ -28,7 +26,6 @@ from pytensor.tensor.extra_ops import broadcast_to from pytensor.tensor.functional import vectorize from pytensor.tensor.math import and_, eq, tensordot -from pytensor.tensor.shape import shape_padright from pytensor.tensor.variable import TensorVariable @@ -63,64 +60,6 @@ def __str__(self): return f"Einsum{{{self.subscripts=}, {self.path=}, {self.optimized=}}}" -def _iota(shape: TensorVariable, axis: int) -> TensorVariable: - """ - Create an array with values increasing along the specified axis. - - Iota is a multidimensional generalization of the `arange` function. The returned array is filled with whole numbers - increasing along the specified axis. - - Parameters - ---------- - shape: TensorVariable - The shape of the array to be created. - axis: int - The axis along which to fill the array with increasing values. - - Returns - ------- - TensorVariable - An array with values increasing along the specified axis. - - Examples - -------- - In the simplest case where ``shape`` is 1d, the output will be equivalent to ``pt.arange``: - - .. testcode:: - - import pytensor.tensor as pt - from pytensor.tensor.einsum import _iota - - shape = pt.as_tensor((5,)) - print(_iota(shape, 0).eval()) - - .. testoutput:: - - [0 1 2 3 4] - - In higher dimensions, it will look like many concatenated `arange`: - - .. testcode:: - - shape = pt.as_tensor((5, 5)) - print(_iota(shape, 1).eval()) - - .. testoutput:: - - [[0 1 2 3 4] - [0 1 2 3 4] - [0 1 2 3 4] - [0 1 2 3 4] - [0 1 2 3 4]] - - Setting ``axis=0`` above would result in the transpose of the output. - """ - len_shape = get_vector_length(shape) - axis = normalize_axis_index(axis, len_shape) - values = arange(shape[axis]) - return broadcast_to(shape_padright(values, len_shape - axis - 1), shape) - - def _delta(shape: TensorVariable, axes: Sequence[int]) -> TensorVariable: """ Create a Kroncker delta tensor. @@ -201,7 +140,7 @@ def _delta(shape: TensorVariable, axes: Sequence[int]) -> TensorVariable: if len(axes) == 1: raise ValueError("Need at least two axes to create a delta tensor") base_shape = stack([shape[axis] for axis in axes]) - iotas = [_iota(base_shape, i) for i in range(len(axes))] + iotas = [iota(base_shape, i) for i in range(len(axes))] eyes = [eq(i1, i2) for i1, i2 in pairwise(iotas)] result = reduce(and_, eyes) non_axes = [i for i in range(len(tuple(shape))) if i not in axes] diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 60643e2984..ba9adb966e 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -58,6 +58,7 @@ identity_like, infer_static_shape, inverse_permutation, + iota, join, make_vector, mgrid, @@ -980,6 +981,29 @@ def test_static_output_type(self): assert eye(1, l, 3).type.shape == (1, None) +def test_iota(): + mode = Mode(linker="py", optimizer=None) + np.testing.assert_allclose( + iota((4, 8), 0).eval(mode=mode), + [ + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1], + [2, 2, 2, 2, 2, 2, 2, 2], + [3, 3, 3, 3, 3, 3, 3, 3], + ], + ) + + np.testing.assert_allclose( + iota((4, 8), 1).eval(mode=mode), + [ + [0, 1, 2, 3, 4, 5, 6, 7], + [0, 1, 2, 3, 4, 5, 6, 7], + [0, 1, 2, 3, 4, 5, 6, 7], + [0, 1, 2, 3, 4, 5, 6, 7], + ], + ) + + class TestTriangle: def test_tri(self): def check(dtype, N, M_=None, k=0): @@ -988,7 +1012,7 @@ def check(dtype, N, M_=None, k=0): M = M_ # Currently DebugMode does not support None as inputs even if this is # allowed. - if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]: + if M is None: # and config.mode in ["DebugMode", "DEBUG_MODE"]: M = N N_symb = iscalar() M_symb = iscalar() @@ -1000,7 +1024,14 @@ def check(dtype, N, M_=None, k=0): assert np.allclose(result, np.tri(N, M_, k, dtype=dtype)) assert result.dtype == np.dtype(dtype) - for dtype in ["int32", "int64", "float32", "float64", "uint16", "complex64"]: + for dtype in [ + "int32", + "int64", + "float32", + "float64", + "uint16", + "complex64", + ]: check(dtype, 3) # M != N, k = 0 check(dtype, 3, 5) @@ -3899,15 +3930,15 @@ def test_Tri(self): biscal = iscalar() ciscal = iscalar() self._compile_and_check( - [aiscal, biscal, ciscal], [Tri()(aiscal, biscal, ciscal)], [4, 4, 0], Tri + [aiscal, biscal, ciscal], [tri(aiscal, biscal, ciscal)], [4, 4, 0], Tri ) self._compile_and_check( - [aiscal, biscal, ciscal], [Tri()(aiscal, biscal, ciscal)], [4, 5, 0], Tri + [aiscal, biscal, ciscal], [tri(aiscal, biscal, ciscal)], [4, 5, 0], Tri ) self._compile_and_check( - [aiscal, biscal, ciscal], [Tri()(aiscal, biscal, ciscal)], [3, 5, 0], Tri + [aiscal, biscal, ciscal], [tri(aiscal, biscal, ciscal)], [3, 5, 0], Tri ) def test_ExtractDiag(self): diff --git a/tests/tensor/test_einsum.py b/tests/tensor/test_einsum.py index ba8e354518..8e4e14855c 100644 --- a/tests/tensor/test_einsum.py +++ b/tests/tensor/test_einsum.py @@ -10,7 +10,7 @@ from pytensor.graph.op import HasInnerGraph from pytensor.tensor.basic import moveaxis from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum +from pytensor.tensor.einsum import _delta, _general_dot, einsum from pytensor.tensor.shape import Reshape from pytensor.tensor.type import tensor @@ -38,29 +38,6 @@ def assert_no_blockwise_in_graph(fgraph: FunctionGraph, core_op=None) -> None: assert_no_blockwise_in_graph(inner_fgraph, core_op=core_op) -def test_iota(): - mode = Mode(linker="py", optimizer=None) - np.testing.assert_allclose( - _iota((4, 8), 0).eval(mode=mode), - [ - [0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1], - [2, 2, 2, 2, 2, 2, 2, 2], - [3, 3, 3, 3, 3, 3, 3, 3], - ], - ) - - np.testing.assert_allclose( - _iota((4, 8), 1).eval(mode=mode), - [ - [0, 1, 2, 3, 4, 5, 6, 7], - [0, 1, 2, 3, 4, 5, 6, 7], - [0, 1, 2, 3, 4, 5, 6, 7], - [0, 1, 2, 3, 4, 5, 6, 7], - ], - ) - - def test_delta(): mode = Mode(linker="py", optimizer=None) np.testing.assert_allclose(