Skip to content

Commit 547eaa5

Browse files
committed
Attempt to add support for scipy.signal.gauss_spline
1 parent bf628c9 commit 547eaa5

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

pytensor/tensor/ssignal.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from pytensor.tensor import Op, as_tensor_variable
2+
from pytensor.graph.basic import Apply
3+
import pytensor.tensor as pt
4+
from pytensor.tensor.type import TensorType
5+
import scipy.signal as scipy_signal
6+
7+
8+
class GaussSpline(Op):
9+
__props__ = ("n",)
10+
11+
def __init__(self, n: int = None):
12+
self.n = n
13+
14+
def make_node(self, knots):
15+
knots = as_tensor_variable(knots)
16+
if not isinstance(knots.type, TensorType):
17+
raise TypeError("Input must be a TensorType")
18+
19+
if not isinstance(self.n, int) or self.n is None or self.n < 0:
20+
raise ValueError("n must be a non-negative integer")
21+
22+
if knots.ndim < 1:
23+
raise TypeError("Input must be at least 1-dimensional")
24+
25+
out = knots.type()
26+
return Apply(self, [knots], [out])
27+
28+
def perform(self, node, inputs, output_storage):
29+
[x] = inputs
30+
[out] = output_storage
31+
out[0] = scipy_signal.gauss_spline(x, self.n)
32+
33+
def infer_shape(self, fgraph, node, shapes):
34+
return [shapes[0]]
35+
36+
37+
def gauss_spline(x, n):
38+
return GaussSpline(n)(x)

tests/tensor/test_ssignal.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from pytensor.tensor.ssignal import GaussSpline, gauss_spline
2+
from pytensor.tensor.type import matrix
3+
from pytensor import function
4+
from pytensor import tensor as pt
5+
import numpy as np
6+
import pytest
7+
from tests import unittest_tools as utt
8+
9+
import scipy.signal as scipy_signal
10+
11+
class TestGaussSpline(utt.InferShapeTester):
12+
13+
def setup_method(self):
14+
super().setup_method()
15+
self.op_class = GaussSpline
16+
self.op = gauss_spline
17+
18+
@pytest.mark.parametrize("n", [-1, 1.5, None, "string"])
19+
def test_make_node_raises(self, n):
20+
a = matrix()
21+
with pytest.raises(ValueError, match="n must be a non-negative integer"):
22+
self.op(a, n=n)
23+
24+
def test_perform(self):
25+
a = matrix()
26+
f = function([a], self.op(a, n=10))
27+
a = np.random.random((8, 6))
28+
assert np.allclose(f(a), scipy_signal.gauss_spline(a, 10))
29+
30+
def test_infer_shape(self):
31+
a = matrix()
32+
self._compile_and_check(
33+
[a], [self.op(a, 16)], [np.random.random((12, 4))], self.op_class
34+
)

0 commit comments

Comments
 (0)