Commit b25ab43

Fix same to agree with scipy
1 parent 8443857 commit b25ab43

File tree (2 files changed: +33 -15 lines changed)

    pytensor/signal/conv.py
    tests/signal/test_conv.py

pytensor/signal/conv.py

Lines changed: 13 additions & 8 deletions
@@ -30,7 +30,7 @@ def make_node(self, data, kernel):
         elif self.mode == "valid":
             out_shape = (max(n, k) - min(n, k) + 1,)
         elif self.mode == "same":
-            out_shape = (max(n, k),)
+            out_shape = (n,)
 
         out = pt.tensor(dtype=dtype, shape=out_shape)
         return Apply(self, [data, kernel], [out])
@@ -48,29 +48,34 @@ def infer_shape(self, fgraph, node, shapes):
         elif self.mode == "valid":
             shape = pt.maximum(n, k) - pt.minimum(n, k) + 1
         elif self.mode == "same":
-            shape = pt.maximum(n, k)
+            shape = n
         return [[shape]]
 
     def L_op(self, inputs, outputs, output_grads):
         data, kernel = inputs
         [grad] = output_grads
 
         if self.mode == "full":
-            valid_conv = type(self)(mode="valid")
-            data_bar = valid_conv(grad, kernel[::-1])
-            kernel_bar = valid_conv(grad, data[::-1])
+            data_bar = convolve(grad, kernel[::-1], mode="valid")
+            kernel_bar = convolve(grad, data[::-1], mode="valid")
 
         elif self.mode == "valid":
-            full_conv = type(self)(mode="full")
             n = data.shape[0]
             k = kernel.shape[0]
             kmn = pt.maximum(0, k - n)
             nkm = pt.maximum(0, n - k)
             # We need mode="full" if k >= n else "valid" for data_bar (opposite for kernel_bar), but mode is not symbolic.
             # Instead we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter.
-            data_bar = full_conv(grad, kernel[::-1])
+            data_bar = convolve(grad, kernel[::-1], mode="full")
             data_bar = data_bar[kmn : data_bar.shape[0] - kmn]
-            kernel_bar = full_conv(grad, data[::-1])
+            kernel_bar = convolve(grad, data[::-1], mode="full")
             kernel_bar = kernel_bar[nkm : kernel_bar.shape[0] - nkm]
 
+        else: # self.mode == "same"
+            raise NotImplementedError("L_op not implemented for mode='same'")
+
         return [data_bar, kernel_bar]
+
+
+def convolve(data, kernel, mode="full"):
+    return Conv1d(mode)(data, kernel)
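
Note (not part of the commit): a minimal check of the behavior this change matches. SciPy's convolve with mode="same" always returns an output the size of its first input, so the Op's "same" output shape should be the data length n rather than max(n, k); the array sizes below are only illustrative.

    import numpy as np
    from scipy.signal import convolve as scipy_convolve

    data = np.ones(5)    # n = 5
    kernel = np.ones(8)  # k = 8 > n
    out = scipy_convolve(data, kernel, mode="same")
    assert out.shape == (5,)  # matches len(data), not max(n, k) = 8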

tests/signal/test_conv.py

Lines changed: 20 additions & 7 deletions
@@ -1,19 +1,32 @@
+from functools import partial
+
 import numpy as np
 import pytest
+from scipy.signal import convolve as scipy_convolve
 
-from pytensor.signal.conv import Conv1d
+from pytensor import function
+from pytensor.signal.conv import convolve
+from pytensor.tensor import vector
 from tests import unittest_tools as utt
 
 
-@pytest.mark.parametrize("data_shape", [3, 5, 8])
-@pytest.mark.parametrize("kernel_shape", [3, 5, 8])
+@pytest.mark.parametrize("kernel_shape", [3, 5, 8], ids=lambda x: f"kernel_shape={x}")
+@pytest.mark.parametrize("data_shape", [3, 5, 8], ids=lambda x: f"data_shape={x}")
 @pytest.mark.parametrize("mode", ["full", "valid", "same"])
-def test_conv1d_grad(mode, data_shape, kernel_shape):
+def test_convolve(mode, data_shape, kernel_shape):
     rng = np.random.default_rng()
 
+    data = vector("data")
+    kernel = vector("kernel")
+    op = partial(convolve, mode=mode)
+
+    rng = np.random.default_rng()
     data_val = rng.normal(size=data_shape)
     kernel_val = rng.normal(size=kernel_shape)
 
-    op = Conv1d(mode=mode)
-
-    utt.verify_grad(op=op, pt=[data_val, kernel_val])
+    fn = function([data, kernel], op(data, kernel))
+    np.testing.assert_allclose(
+        fn(data_val, kernel_val),
+        scipy_convolve(data_val, kernel_val, mode=mode),
+    )
+    utt.verify_grad(op=lambda x: op(x, kernel_val), pt=[data_val])
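
Note (not part of the commit): a short usage sketch of the new module-level convolve helper exercised by the test above, assuming the import paths introduced in this commit; the concrete values are only illustrative.

    import numpy as np
    from pytensor import function
    from pytensor.signal.conv import convolve
    from pytensor.tensor import vector

    data = vector("data")
    kernel = vector("kernel")
    out = convolve(data, kernel, mode="same")

    fn = function([data, kernel], out)
    result = fn(np.ones(5), np.array([1.0, 2.0, 1.0]))
    print(result.shape)  # (5,): output length follows the data, as in SciPy's mode="same"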
