Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytensor/link/jax/dispatch/tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,6 @@ def tri(*args):
x if const_x is None else const_x
for x, const_x in zip(args, const_args, strict=True)
]
return jnp.tri(*args, dtype=op.dtype)
return jnp.tri(*args)

return tri
107 changes: 75 additions & 32 deletions pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,65 @@ def flatnonzero(a):
return nonzero(_a.flatten(), return_matrix=False)[0]


def iota(shape: TensorVariable, axis: int) -> TensorVariable:
"""
Create an array with values increasing along the specified axis.

Iota is a multidimensional generalization of the `arange` function. The returned array is filled with whole numbers
increasing along the specified axis.

Parameters
----------
shape: TensorVariable
The shape of the array to be created.
axis: int
The axis along which to fill the array with increasing values.

Returns
-------
TensorVariable
An array with values increasing along the specified axis.

Examples
--------
In the simplest case where ``shape`` is 1d, the output will be equivalent to ``pt.arange``:

.. testcode::

import pytensor.tensor as pt

shape = pt.as_tensor((5,))
print(pt.basic.iota(shape, 0).eval())

.. testoutput::

[0 1 2 3 4]

In higher dimensions, it will look like many concatenated `arange`:

.. testcode::

shape = pt.as_tensor((5, 5))
print(pt.basic.iota(shape, 1).eval())

.. testoutput::

[[0 1 2 3 4]
[0 1 2 3 4]
[0 1 2 3 4]
[0 1 2 3 4]
[0 1 2 3 4]]

Setting ``axis=0`` above would result in the transpose of the output.
"""
len_shape = get_vector_length(shape)
axis = normalize_axis_index(axis, len_shape)
values = arange(shape[axis])
return pytensor.tensor.extra_ops.broadcast_to(
shape_padright(values, len_shape - axis - 1), shape
)


def nonzero_values(a):
"""Return a vector of non-zero elements contained in the input array.

Expand All @@ -1084,35 +1143,10 @@ def nonzero_values(a):
return _a.flatten()[flatnonzero(_a)]


class Tri(Op):
__props__ = ("dtype",)

def __init__(self, dtype=None):
if dtype is None:
dtype = config.floatX
self.dtype = dtype

def make_node(self, N, M, k):
N = as_tensor_variable(N)
M = as_tensor_variable(M)
k = as_tensor_variable(k)
return Apply(
self,
[N, M, k],
[TensorType(dtype=self.dtype, shape=(None, None))()],
)

def perform(self, node, inp, out_):
N, M, k = inp
(out,) = out_
out[0] = np.tri(N, M, k, dtype=self.dtype)

def infer_shape(self, fgraph, node, in_shapes):
out_shape = [node.inputs[0], node.inputs[1]]
return [out_shape]

def grad(self, inp, grads):
return [grad_undefined(self, i, inp[i]) for i in range(3)]
class Tri(OpFromGraph):
"""
Wrapper Op for np.tri graphs
"""


def tri(N, M=None, k=0, dtype=None):
Expand Down Expand Up @@ -1142,10 +1176,19 @@ def tri(N, M=None, k=0, dtype=None):
"""
if dtype is None:
dtype = config.floatX
dtype = np.dtype(dtype)

if M is None:
M = N
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
M = N
M = N.copy()

OpFromGraph can freak out sometimes if a single input is re-used

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jessegrabowski Makes sense, I'll make the change. But it still resulted in the same error.

Stack trace (formatting may be unclear)

tests\link\jax\test_tensor_basic.py F [100%]

====================================================== FAILURES =======================================================
______________________________________________________ test_tri _______________________________________________________

def streamline_default_f():
    for x in no_recycling:
        x[0] = None
    try:
        # strict=False because we are in a hot loop
        for thunk, node, old_storage in zip(
            thunks, order, post_thunk_old_storage, strict=False
        ):
          thunk()

pytensor\link\utils.py:197:


pytensor\graph\op.py:531: in rval
r = p(n, [x[0] for x in i], o)
pytensor\compile\builders.py:875: in perform
variables = self.fn(*inputs)
pytensor\compile\builders.py:856: in fn
self.fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
pytensor\compile\function_init
.py:332: in function
fn = pfunc(
pytensor\compile\function\pfunc.py:472: in pfunc
return orig_function(
pytensor\compile\function\types.py:1820: in orig_function
m = Maker(
pytensor\compile\function\types.py:1567: in init
self.check_unused_inputs(inputs, outputs, on_unused_input)


inputs = [In(*0-<Scalar(int8, shape=())>), In(*1-<Scalar(int8, shape=())>), In(*2-<Scalar(int8, shape=())>)]
outputs = [Out(Cast{float64}.0,False)], on_unused_input = 'raise'

@staticmethod
def check_unused_inputs(inputs, outputs, on_unused_input):
    if on_unused_input is None:
        on_unused_input = config.on_unused_input

    if on_unused_input == "ignore":
        return

    # There should be two categories of variables in inputs:
    #  - variables that have to be provided (used_inputs)
    #  - shared variables that will be updated
    used_inputs = list(
        ancestors(
            (
                [o.variable for o in outputs]
                + [
                    i.update
                    for i in inputs
                    if getattr(i, "update", None) is not None
                ]
            ),
            blockers=[i.variable for i in inputs],
        )
    )

    msg = (
        "pytensor.function was asked to create a function computing "
        "outputs given certain inputs, but the provided input "
        "variable at index %i is not part of the computational graph "
        "needed to compute the outputs: %s.\n%s"
    )
    warn_msg = (
        "To make this warning into an error, you can pass the "
        "parameter on_unused_input='raise' to pytensor.function. "
        "To disable it completely, use on_unused_input='ignore'."
    )
    err_msg = (
        "To make this error into a warning, you can pass the "
        "parameter on_unused_input='warn' to pytensor.function. "
        "To disable it completely, use on_unused_input='ignore'."
    )

    for i in inputs:
        if (i.variable not in used_inputs) and (i.update is None):
            if on_unused_input == "warn":
                warnings.warn(
                    msg % (inputs.index(i), i.variable, warn_msg), stacklevel=6
                )
            elif on_unused_input == "raise":
              raise UnusedInputError(msg % (inputs.index(i), i.variable, err_msg))

E pytensor.compile.function.types.UnusedInputError: pytensor.function was asked to create a function computing outputs given certain inputs, but the provided input variable at index 0 is not part of the computational graph needed to compute the outputs: *0-<Scalar(int8, shape=())>.
E To make this error into a warning, you can pass the parameter on_unused_input='warn' to pytensor.function. To disable it completely, use on_unused_input='ignore'.

pytensor\compile\function\types.py:1438: UnusedInputError

During handling of the above exception, another exception occurred:

def test_tri():
    out = ptb.tri(10, 10, 0)
  compare_jax_and_py([], [out], [])

tests\link\jax\test_tensor_basic.py:207:


tests\link\jax\test_basic.py:87: in compare_jax_and_py
py_res = pytensor_py_fn(*test_inputs)
pytensor\compile\function\types.py:1037: in call
outputs = vm() if output_subset is None else vm(output_subset=output_subset)
pytensor\link\utils.py:201: in streamline_default_f
raise_with_op(fgraph, node, thunk)
pytensor\link\utils.py:526: in raise_with_op
raise exc_value.with_traceback(exc_trace)
pytensor\link\utils.py:197: in streamline_default_f
thunk()
pytensor\graph\op.py:531: in rval
r = p(n, [x[0] for x in i], o)
pytensor\compile\builders.py:875: in perform
variables = self.fn(*inputs)
pytensor\compile\builders.py:856: in fn
self.fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
pytensor\compile\function_init
.py:332: in function
fn = pfunc(
pytensor\compile\function\pfunc.py:472: in pfunc
return orig_function(
pytensor\compile\function\types.py:1820: in orig_function
m = Maker(
pytensor\compile\function\types.py:1567: in init
self.check_unused_inputs(inputs, outputs, on_unused_input)


inputs = [In(*0-<Scalar(int8, shape=())>), In(*1-<Scalar(int8, shape=())>), In(*2-<Scalar(int8, shape=())>)]
outputs = [Out(Cast{float64}.0,False)], on_unused_input = 'raise'

@staticmethod
def check_unused_inputs(inputs, outputs, on_unused_input):
    if on_unused_input is None:
        on_unused_input = config.on_unused_input

    if on_unused_input == "ignore":
        return

    # There should be two categories of variables in inputs:
    #  - variables that have to be provided (used_inputs)
    #  - shared variables that will be updated
    used_inputs = list(
        ancestors(
            (
                [o.variable for o in outputs]
                + [
                    i.update
                    for i in inputs
                    if getattr(i, "update", None) is not None
                ]
            ),
            blockers=[i.variable for i in inputs],
        )
    )

    msg = (
        "pytensor.function was asked to create a function computing "
        "outputs given certain inputs, but the provided input "
        "variable at index %i is not part of the computational graph "
        "needed to compute the outputs: %s.\n%s"
    )
    warn_msg = (
        "To make this warning into an error, you can pass the "
        "parameter on_unused_input='raise' to pytensor.function. "
        "To disable it completely, use on_unused_input='ignore'."
    )
    err_msg = (
        "To make this error into a warning, you can pass the "
        "parameter on_unused_input='warn' to pytensor.function. "
        "To disable it completely, use on_unused_input='ignore'."
    )

    for i in inputs:
        if (i.variable not in used_inputs) and (i.update is None):
            if on_unused_input == "warn":
                warnings.warn(
                    msg % (inputs.index(i), i.variable, warn_msg), stacklevel=6
                )
            elif on_unused_input == "raise":
              raise UnusedInputError(msg % (inputs.index(i), i.variable, err_msg))

E pytensor.compile.function.types.UnusedInputError: pytensor.function was asked to create a function computing outputs given certain inputs, but the provided input variable at index 0 is not part of the computational graph needed to compute the outputs: *0-<Scalar(int8, shape=())>.
E To make this error into a warning, you can pass the parameter on_unused_input='warn' to pytensor.function. To disable it completely, use on_unused_input='ignore'.
E Apply node that caused the error: Tri{inline=False}(10, 10, 0)
E Toposort index: 0
E Inputs types: [TensorType(int8, shape=()), TensorType(int8, shape=()), TensorType(int8, shape=())]
E Inputs shapes: [(), (), ()]
E Inputs strides: [(), (), ()]
E Inputs values: [array(10, dtype=int8), array(10, dtype=int8), array(0, dtype=int8)]
E Outputs clients: [[output0]]
E
E Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
E File "C:\Users\Public\miniforge3\envs\pytensor-dev\Lib\site-packages\pluggy_callers.py", line 103, in _multicall
E res = hook_impl.function(*args)
E File "C:\Users\Public\miniforge3\envs\pytensor-dev\Lib\site-packages_pytest\runner.py", line 174, in pytest_runtest_call
E item.runtest()
E File "C:\Users\Public\miniforge3\envs\pytensor-dev\Lib\site-packages_pytest\python.py", line 1627, in runtest
E self.ihook.pytest_pyfunc_call(pyfuncitem=self)
E File "C:\Users\Public\miniforge3\envs\pytensor-dev\Lib\site-packages\pluggy_hooks.py", line 513, in call
E return self._hookexec(self.name, self._hookimpls.copy(), kwargs, firstresult)
E File "C:\Users\Public\miniforge3\envs\pytensor-dev\Lib\site-packages\pluggy_manager.py", line 120, in _hookexec
E return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E File "C:\Users\Public\miniforge3\envs\pytensor-dev\Lib\site-packages\pluggy_callers.py", line 103, in _multicall
E res = hook_impl.function(*args)
E File "C:\Users\Public\miniforge3\envs\pytensor-dev\Lib\site-packages_pytest\python.py", line 159, in pytest_pyfunc_call
E result = testfunction(**testargs)
E File "C:\Users\Nimish Purohit\pytensor\tests\link\jax\test_tensor_basic.py", line 205, in test_tri
E out = ptb.tri(10, 10, 0)
E
E HINT: Use the PyTensor flag exception_verbosity=high for a debug print-out and storage map footprint of this Apply node.

pytensor\compile\function\types.py:1438: UnusedInputError
================================================ slowest 50 durations =================================================
0.22s call tests/link/jax/test_tensor_basic.py::test_tri

(2 durations < 0.005s hidden. Use -vv to show these durations.)
=============================================== short test summary info ===============================================
FAILED tests/link/jax/test_tensor_basic.py::test_tri - pytensor.compile.function.types.UnusedInputError: pytensor.function was asked to create a function computing output...
================================================== 1 failed in 5.51s ==================================================

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll have a look more closely at what is going on over the weekend. It's not obvious to me at first glance

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jessegrabowski Any idea on how to proceed? The test function call will result in this call, which is causing the UnusedInputError. This call will contain an empty graph_inputs, compared to the normal python test cases that are passing.

op = Tri(dtype)
return op(N, M, k)

N = as_tensor_variable(N)
M = as_tensor_variable(M)
k = as_tensor_variable(k)

output = ((iota(as_tensor((N, 1)), 0) + k + 1) > iota(as_tensor((1, M)), 1)).astype(
dtype
)
return Tri(inputs=[N, M, k], outputs=[output])(N, M, k)


def tril(m, k=0):
Expand Down Expand Up @@ -1225,7 +1268,7 @@ def triu(m, k=0):
[ 0, 8, 9],
[ 0, 0, 12]])

>>> pt.triu(np.arange(3 * 4 * 5).reshape((3, 4, 5))).eval()
>>> pt.triu(pt.arange(3 * 4 * 5).reshape((3, 4, 5))).eval()
array([[[ 0, 1, 2, 3, 4],
[ 0, 6, 7, 8, 9],
[ 0, 0, 12, 13, 14],
Expand Down
65 changes: 2 additions & 63 deletions pytensor/tensor/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@
from pytensor.npy_2_compat import (
_find_contraction,
_parse_einsum_input,
normalize_axis_index,
normalize_axis_tuple,
)
from pytensor.tensor import TensorLike
from pytensor.tensor.basic import (
arange,
as_tensor,
expand_dims,
get_vector_length,
iota,
moveaxis,
stack,
transpose,
Expand All @@ -28,7 +26,6 @@
from pytensor.tensor.extra_ops import broadcast_to
from pytensor.tensor.functional import vectorize
from pytensor.tensor.math import and_, eq, tensordot
from pytensor.tensor.shape import shape_padright
from pytensor.tensor.variable import TensorVariable


Expand Down Expand Up @@ -63,64 +60,6 @@ def __str__(self):
return f"Einsum{{{self.subscripts=}, {self.path=}, {self.optimized=}}}"


def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
"""
Create an array with values increasing along the specified axis.

Iota is a multidimensional generalization of the `arange` function. The returned array is filled with whole numbers
increasing along the specified axis.

Parameters
----------
shape: TensorVariable
The shape of the array to be created.
axis: int
The axis along which to fill the array with increasing values.

Returns
-------
TensorVariable
An array with values increasing along the specified axis.

Examples
--------
In the simplest case where ``shape`` is 1d, the output will be equivalent to ``pt.arange``:

.. testcode::

import pytensor.tensor as pt
from pytensor.tensor.einsum import _iota

shape = pt.as_tensor((5,))
print(_iota(shape, 0).eval())

.. testoutput::

[0 1 2 3 4]

In higher dimensions, it will look like many concatenated `arange`:

.. testcode::

shape = pt.as_tensor((5, 5))
print(_iota(shape, 1).eval())

.. testoutput::

[[0 1 2 3 4]
[0 1 2 3 4]
[0 1 2 3 4]
[0 1 2 3 4]
[0 1 2 3 4]]

Setting ``axis=0`` above would result in the transpose of the output.
"""
len_shape = get_vector_length(shape)
axis = normalize_axis_index(axis, len_shape)
values = arange(shape[axis])
return broadcast_to(shape_padright(values, len_shape - axis - 1), shape)


def _delta(shape: TensorVariable, axes: Sequence[int]) -> TensorVariable:
"""
Create a Kroncker delta tensor.
Expand Down Expand Up @@ -201,7 +140,7 @@ def _delta(shape: TensorVariable, axes: Sequence[int]) -> TensorVariable:
if len(axes) == 1:
raise ValueError("Need at least two axes to create a delta tensor")
base_shape = stack([shape[axis] for axis in axes])
iotas = [_iota(base_shape, i) for i in range(len(axes))]
iotas = [iota(base_shape, i) for i in range(len(axes))]
eyes = [eq(i1, i2) for i1, i2 in pairwise(iotas)]
result = reduce(and_, eyes)
non_axes = [i for i in range(len(tuple(shape))) if i not in axes]
Expand Down
41 changes: 36 additions & 5 deletions tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
identity_like,
infer_static_shape,
inverse_permutation,
iota,
join,
make_vector,
mgrid,
Expand Down Expand Up @@ -980,6 +981,29 @@ def test_static_output_type(self):
assert eye(1, l, 3).type.shape == (1, None)


def test_iota():
mode = Mode(linker="py", optimizer=None)
np.testing.assert_allclose(
iota((4, 8), 0).eval(mode=mode),
[
[0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3, 3, 3],
],
)

np.testing.assert_allclose(
iota((4, 8), 1).eval(mode=mode),
[
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
],
)


class TestTriangle:
def test_tri(self):
def check(dtype, N, M_=None, k=0):
Expand All @@ -988,7 +1012,7 @@ def check(dtype, N, M_=None, k=0):
M = M_
# Currently DebugMode does not support None as inputs even if this is
# allowed.
if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]:
if M is None: # and config.mode in ["DebugMode", "DEBUG_MODE"]:
M = N
N_symb = iscalar()
M_symb = iscalar()
Expand All @@ -1000,7 +1024,14 @@ def check(dtype, N, M_=None, k=0):
assert np.allclose(result, np.tri(N, M_, k, dtype=dtype))
assert result.dtype == np.dtype(dtype)

for dtype in ["int32", "int64", "float32", "float64", "uint16", "complex64"]:
for dtype in [
"int32",
"int64",
"float32",
"float64",
"uint16",
"complex64",
]:
check(dtype, 3)
# M != N, k = 0
check(dtype, 3, 5)
Expand Down Expand Up @@ -3899,15 +3930,15 @@ def test_Tri(self):
biscal = iscalar()
ciscal = iscalar()
self._compile_and_check(
[aiscal, biscal, ciscal], [Tri()(aiscal, biscal, ciscal)], [4, 4, 0], Tri
[aiscal, biscal, ciscal], [tri(aiscal, biscal, ciscal)], [4, 4, 0], Tri
)

self._compile_and_check(
[aiscal, biscal, ciscal], [Tri()(aiscal, biscal, ciscal)], [4, 5, 0], Tri
[aiscal, biscal, ciscal], [tri(aiscal, biscal, ciscal)], [4, 5, 0], Tri
)

self._compile_and_check(
[aiscal, biscal, ciscal], [Tri()(aiscal, biscal, ciscal)], [3, 5, 0], Tri
[aiscal, biscal, ciscal], [tri(aiscal, biscal, ciscal)], [3, 5, 0], Tri
)

def test_ExtractDiag(self):
Expand Down
25 changes: 1 addition & 24 deletions tests/tensor/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pytensor.graph.op import HasInnerGraph
from pytensor.tensor.basic import moveaxis
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum
from pytensor.tensor.einsum import _delta, _general_dot, einsum
from pytensor.tensor.shape import Reshape
from pytensor.tensor.type import tensor

Expand Down Expand Up @@ -38,29 +38,6 @@ def assert_no_blockwise_in_graph(fgraph: FunctionGraph, core_op=None) -> None:
assert_no_blockwise_in_graph(inner_fgraph, core_op=core_op)


def test_iota():
mode = Mode(linker="py", optimizer=None)
np.testing.assert_allclose(
_iota((4, 8), 0).eval(mode=mode),
[
[0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3, 3, 3],
],
)

np.testing.assert_allclose(
_iota((4, 8), 1).eval(mode=mode),
[
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
],
)


def test_delta():
mode = Mode(linker="py", optimizer=None)
np.testing.assert_allclose(
Expand Down
Loading