Skip to content

Commit 158f223

Browse files
committed
Faster implementation of numba convolve1d
1 parent 2cc864b commit 158f223

File tree

2 files changed

+85
-7
lines changed

2 files changed

+85
-7
lines changed
Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from numba.np.arraymath import _get_inner_prod
23

34
from pytensor.link.numba.dispatch import numba_funcify
45
from pytensor.link.numba.dispatch.basic import numba_njit
@@ -7,10 +8,66 @@
78

89
@numba_funcify.register(Convolve1d)
910
def numba_funcify_Conv1d(op, node, **kwargs):
11+
# This specialized version is faster than the overloaded numba np.convolve,
12+
# as it avoids several runtime checks that don't seem to be inlined.
1013
mode = op.mode
1114

12-
@numba_njit
13-
def conv1d(data, kernel):
14-
return np.convolve(data, kernel, mode=mode)
15+
a_dt = np.dtype(node.inputs[0].dtype)
16+
b_dt = np.dtype(node.inputs[1].dtype)
17+
dt = np.promote_types(a_dt, b_dt)
18+
innerprod = _get_inner_prod(a_dt, b_dt)
1519

16-
return conv1d
20+
if mode == "valid":
21+
22+
def valid_convolve1d(x, y):
23+
nx = len(x)
24+
ny = len(y)
25+
if nx < ny:
26+
x, y = y, x
27+
nx, ny = ny, nx
28+
y_flipped = y[::-1]
29+
30+
length = nx - ny + 1
31+
ret = np.empty(length, dt)
32+
33+
for i in range(length):
34+
ret[i] = innerprod(x[i : i + ny], y_flipped)
35+
36+
return ret
37+
38+
return numba_njit(valid_convolve1d)
39+
40+
elif mode == "full":
41+
42+
def full_convolve1d(x, y):
43+
nx = len(x)
44+
ny = len(y)
45+
if nx < ny:
46+
x, y = y, x
47+
nx, ny = ny, nx
48+
y_flipped = y[::-1]
49+
50+
length = nx + ny - 1
51+
ret = np.empty(length, dt)
52+
idx = 0
53+
54+
for i in range(ny - 1):
55+
k = i + 1
56+
ret[idx] = innerprod(x[:k], y_flipped[-k:])
57+
idx = idx + 1
58+
59+
for i in range(nx - ny + 1):
60+
ret[idx] = innerprod(x[i : i + ny], y_flipped)
61+
idx = idx + 1
62+
63+
for i in range(ny - 1):
64+
k = ny - i - 1
65+
ret[idx] = innerprod(x[-k:], y_flipped[:k])
66+
idx = idx + 1
67+
68+
return ret
69+
70+
return numba_njit(full_convolve1d)
71+
72+
else:
73+
raise ValueError(f"Unsupported mode: {mode}")
Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import numpy as np
22
import pytest
33

4-
from pytensor.tensor import dmatrix
4+
from pytensor import function
5+
from pytensor.tensor import dmatrix, vector
56
from pytensor.tensor.signal import convolve1d
67
from tests.link.numba.test_basic import compare_numba_and_py
78

@@ -10,13 +11,33 @@
1011

1112

1213
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
13-
def test_convolve1d(mode):
14+
@pytest.mark.parametrize("x_smaller", (False, True))
15+
def test_convolve1d(x_smaller, mode):
1416
x = dmatrix("x")
1517
y = dmatrix("y")
16-
out = convolve1d(x[None], y[:, None], mode=mode)
18+
if x_smaller:
19+
out = convolve1d(x[None], y[:, None], mode=mode)
20+
else:
21+
out = convolve1d(y[:, None], x[None], mode=mode)
1722

1823
rng = np.random.default_rng()
1924
test_x = rng.normal(size=(3, 5))
2025
test_y = rng.normal(size=(7, 11))
2126
# Blockwise dispatch for numba can't be run on object mode
2227
compare_numba_and_py([x, y], out, [test_x, test_y], eval_obj_mode=False)
28+
29+
30+
@pytest.mark.parametrize("mode", ("full", "valid"))
31+
def test_convolve_benchmark(mode, benchmark):
32+
x = vector(shape=(183,))
33+
y = vector(shape=(6,))
34+
out = convolve1d(x, y, mode=mode)
35+
fn = function([x, y], out, mode="NUMBA", trust_input=True)
36+
37+
rng = np.random.default_rng()
38+
x_test = rng.normal(size=(x.type.shape)).astype(x.type.dtype)
39+
y_test = rng.normal(size=(y.type.shape)).astype(y.type.dtype)
40+
np.testing.assert_allclose(
41+
fn(x_test, y_test), np.convolve(x_test, y_test, mode=mode)
42+
)
43+
benchmark(fn, x_test, y_test)

0 commit comments

Comments
 (0)