From 14a99d88656e8c83fef0c982ec6eb90620a01659 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 24 Aug 2024 13:40:40 +0200 Subject: [PATCH 1/2] OpFromGraph subclasses shouldn't have __props__ When specified, Ops with identical __props__ are considered identical, in that they can be swapped and given the original inputs to obtain the same output. --- pytensor/tensor/basic.py | 5 +++-- pytensor/tensor/einsum.py | 5 +++-- tests/tensor/test_basic.py | 13 +++++++++++++ 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 9eaa04c522..7ae0152b82 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -3780,8 +3780,6 @@ class AllocDiag(OpFromGraph): Wrapper Op for alloc_diag graphs """ - __props__ = ("axis1", "axis2") - def __init__(self, *args, axis1, axis2, offset, **kwargs): self.axis1 = axis1 self.axis2 = axis2 @@ -3789,6 +3787,9 @@ def __init__(self, *args, axis1, axis2, offset, **kwargs): super().__init__(*args, **kwargs, strict=True) + def __str__(self): + return f"AllocDiag{{{self.axis1=}, {self.axis2=}, {self.offset=}}}" + @staticmethod def is_offset_zero(node) -> bool: """ diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py index 79151a91a2..736af9809b 100644 --- a/pytensor/tensor/einsum.py +++ b/pytensor/tensor/einsum.py @@ -52,14 +52,15 @@ class Einsum(OpFromGraph): desired. We haven't decided whether we want to provide this functionality. """ - __props__ = ("subscripts", "path", "optimized") - def __init__(self, *args, subscripts: str, path: PATH, optimized: bool, **kwargs): self.subscripts = subscripts self.path = path self.optimized = optimized super().__init__(*args, **kwargs, strict=True) + def __str__(self): + return f"Einsum{{{self.subscripts=}, {self.path=}, {self.optimized=}}}" + def _iota(shape: TensorVariable, axis: int) -> TensorVariable: """ diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 58d4de2481..05aa15aa05 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -37,6 +37,7 @@ TensorFromScalar, Tri, alloc, + alloc_diag, arange, as_tensor_variable, atleast_Nd, @@ -3793,6 +3794,18 @@ def test_alloc_diag_values(self): ) assert np.all(true_grad_input == grad_input) + def test_multiple_ops_same_graph(self): + """Regression test when AllocDiag OFG was given insufficient props, causing incompatible Ops to be merged.""" + v1 = vector("v1", shape=(2,), dtype="float64") + v2 = vector("v2", shape=(3,), dtype="float64") + a1 = alloc_diag(v1) + a2 = alloc_diag(v2) + + fn = function([v1, v2], [a1, a2]) + res1, res2 = fn(v1=[np.e, np.e], v2=[np.pi, np.pi, np.pi]) + np.testing.assert_allclose(res1, np.eye(2) * np.e) + np.testing.assert_allclose(res2, np.eye(3) * np.pi) + def test_diagonal_negative_axis(): x = np.arange(2 * 3 * 3).reshape((2, 3, 3)) From fbd1e4df842d67d1303b7f846e994357d45df89b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 24 Aug 2024 12:24:26 +0200 Subject: [PATCH 2/2] Add xfail for numba failing tests reported in https://github.com/pymc-devs/pytensor/issues/980 --- tests/link/numba/test_cython_support.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/link/numba/test_cython_support.py b/tests/link/numba/test_cython_support.py index 65d1947c9d..1613ff638b 100644 --- a/tests/link/numba/test_cython_support.py +++ b/tests/link/numba/test_cython_support.py @@ -76,19 +76,25 @@ def test_signature_provides(have, want, should_provide): [np.float64], float64(float64, int32), ), - ( + pytest.param( # expn doesn't have a float32 implementation scipy.special.cython_special.expn, np.float32, [np.float32, np.float32], float64(float64, float64, int32), + marks=pytest.mark.xfail( + reason="Failing in newer versions: https://github.com/pymc-devs/pytensor/issues/980" + ), ), - ( + pytest.param( # We choose the integer implementation if possible scipy.special.cython_special.expn, np.float32, [np.int64, np.float32], float64(int64, float64, int32), + marks=pytest.mark.xfail( + reason="Failing in newer versions: https://github.com/pymc-devs/pytensor/issues/980" + ), ), ], )