diff --git a/pytensor/link/pytorch/dispatch/nlinalg.py b/pytensor/link/pytorch/dispatch/nlinalg.py
new file mode 100644
index 0000000000..6785e479f0
--- /dev/null
+++ b/pytensor/link/pytorch/dispatch/nlinalg.py
@@ -0,0 +1,26 @@
+import torch
+
+from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+from pytensor.tensor.math import Argmax
+
+
+@pytorch_funcify.register(Argmax)
+def pytorch_funcify_Argmax(op, **kwargs):
+    # `Argmax.axis` is either None (reduce over all axes) or a tuple of ints.
+    # NOTE(review): the op itself carries no keepdims flag -- keepdims is
+    # applied at the graph level by a DimShuffle, so it must not be read here.
+    axis = op.axis
+
+    def argmax(x):
+        if axis is None:
+            # Flattened argmax over every element, as in np.argmax(x, axis=None).
+            return torch.argmax(x)
+
+        # torch.argmax only accepts a single integer dim, so collapse the
+        # requested axes into one trailing dimension and reduce over it.
+        axes = tuple(int(ax) for ax in axis)
+        keep_axes = [i for i in range(x.ndim) if i not in axes]
+        transposed = x.permute(*keep_axes, *axes)
+        kept_shape = transposed.shape[: len(keep_axes)]
+        return torch.argmax(transposed.reshape(*kept_shape, -1), dim=-1)
+
+    return argmax
diff --git a/tests/link/pytorch/test_nlinalg.py b/tests/link/pytorch/test_nlinalg.py
new file mode 100644
index 0000000000..8bbbf21056
--- /dev/null
+++ b/tests/link/pytorch/test_nlinalg.py
@@ -0,0 +1,25 @@
+import numpy as np
+import pytest
+
+from pytensor.configdefaults import config
+from pytensor.graph import FunctionGraph
+from pytensor.graph.op import get_test_value
+from pytensor.tensor.math import argmax
+from pytensor.tensor.type import matrix
+from tests.link.pytorch.test_basic import compare_pytorch_and_py
+
+
+@pytest.mark.parametrize(
+    "keepdims",
+    [True, False],
+)
+@pytest.mark.parametrize(
+    "axis",
+    [None, 1, (0,)],
+)
+def test_pytorch_argmax(axis, keepdims):
+    a = matrix("a", dtype=config.floatX)
+    a.tag.test_value = np.random.randn(4, 4).astype(config.floatX)
+    amx = argmax(a, axis=axis, keepdims=keepdims)
+    fgraph = FunctionGraph([a], amx)
+    compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])