
Commit 1d82fb4

Make convolve mode symbolic to avoid unnecessary large convolution in gradient
1 parent a62e785 commit 1d82fb4
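
What the commit does, in outline: the convolution mode ("full" vs. "valid") moves from a static attribute of the Convolve1d Op to a third symbolic boolean input. With the mode flowing through the graph as data, rewrites and the gradient can select the smaller "valid" output at runtime instead of always building the larger full convolution (the "unnecessary large convolution in gradient" of the title). A minimal NumPy sketch of the intended semantics; the function name and flag are illustrative, not the actual Op interface:

    import numpy as np

    def convolve1d_ref(x, y, full_mode):
        # full_mode is data, not a property baked into the Op:
        # True yields length nx + ny - 1, False yields nx - ny + 1.
        return np.convolve(x, y, mode="full" if full_mode else "valid")

    x = np.arange(5.0)
    y = np.ones(3)
    assert convolve1d_ref(x, y, True).shape == (7,)   # full: nx + ny - 1
    assert convolve1d_ref(x, y, False).shape == (3,)  # valid: nx - ny + 1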

File tree: 8 files changed, +168 -213 lines
Lines changed: 13 additions & 3 deletions
@@ -1,14 +1,24 @@
 import jax
 
 from pytensor.link.jax.dispatch import jax_funcify
+from pytensor.tensor.basic import get_underlying_scalar_constant_value
+from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.signal.conv import Convolve1d
 
 
 @jax_funcify.register(Convolve1d)
 def jax_funcify_Convolve1d(op, node, **kwargs):
-    mode = op.mode
+    _, _, full_mode = node.inputs
+    try:
+        full_mode = get_underlying_scalar_constant_value(full_mode)
+    except NotScalarConstantError:
+        raise NotImplementedError(
+            "Cannot compile Convolve1D to jax without static mode"
+        )
+    static_mode = "full" if full_mode else "valid"
 
-    def conv1d(data, kernel):
-        return jax.numpy.convolve(data, kernel, mode=mode)
+    def conv1d(data, kernel, _runtime_full_mode):
+        # _runtime_full_mode is not used, as we only support static mode
+        return jax.numpy.convolve(data, kernel, mode=static_mode)
 
     return conv1d
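
The JAX backend still needs the mode at compile time: the mode decides the output length, and output shapes must be known when jax.jit traces the function, which is why the dispatcher above tries to constant-fold the symbolic flag and raises NotImplementedError when it cannot. A standalone illustration of that constraint (assumes jax is installed; not part of the diff):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def conv_full(x, y):
        # mode is a static Python string, so the output shape is fixed at trace time
        return jnp.convolve(x, y, mode="full")

    assert conv_full(jnp.arange(5.0), jnp.ones(3)).shape == (7,)  # 5 + 3 - 1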

pytensor/link/numba/dispatch/signal/conv.py

Lines changed: 54 additions & 55 deletions
@@ -9,62 +9,61 @@
 @numba_funcify.register(Convolve1d)
 def numba_funcify_Convolve1d(op, node, **kwargs):
     # This specialized version is faster than the overloaded numba np.convolve
-    mode = op.mode
     a_dtype, b_dtype = node.inputs[0].type.dtype, node.inputs[1].type.dtype
     out_dtype = node.outputs[0].type.dtype
     innerprod = _get_inner_prod(a_dtype, b_dtype)
 
-    if mode == "valid":
-
-        def valid_convolve1d(x, y):
-            nx = len(x)
-            ny = len(y)
-            if nx < ny:
-                x, y = y, x
-                nx, ny = ny, nx
-            y_flipped = y[::-1]
-
-            length = nx - ny + 1
-            ret = np.empty(length, out_dtype)
-
-            for i in range(length):
-                ret[i] = innerprod(x[i : i + ny], y_flipped)
-
-            return ret
-
-        return numba_njit(valid_convolve1d)
-
-    elif mode == "full":
-
-        def full_convolve1d(x, y):
-            nx = len(x)
-            ny = len(y)
-            if nx < ny:
-                x, y = y, x
-                nx, ny = ny, nx
-            y_flipped = y[::-1]
-
-            length = nx + ny - 1
-            ret = np.empty(length, out_dtype)
-            idx = 0
-
-            for i in range(ny - 1):
-                k = i + 1
-                ret[idx] = innerprod(x[:k], y_flipped[-k:])
-                idx = idx + 1
-
-            for i in range(nx - ny + 1):
-                ret[idx] = innerprod(x[i : i + ny], y_flipped)
-                idx = idx + 1
-
-            for i in range(ny - 1):
-                k = ny - i - 1
-                ret[idx] = innerprod(x[-k:], y_flipped[:k])
-                idx = idx + 1
-
-            return ret
-
-        return numba_njit(full_convolve1d)
-
-    else:
-        raise ValueError(f"Unsupported mode: {mode}")
+    @numba_njit
+    def valid_convolve1d(x, y):
+        nx = len(x)
+        ny = len(y)
+        if nx < ny:
+            x, y = y, x
+            nx, ny = ny, nx
+        y_flipped = y[::-1]
+
+        length = nx - ny + 1
+        ret = np.empty(length, out_dtype)
+
+        for i in range(length):
+            ret[i] = innerprod(x[i : i + ny], y_flipped)
+
+        return ret
+
+    @numba_njit
+    def full_convolve1d(x, y):
+        nx = len(x)
+        ny = len(y)
+        if nx < ny:
+            x, y = y, x
+            nx, ny = ny, nx
+        y_flipped = y[::-1]
+
+        length = nx + ny - 1
+        ret = np.empty(length, out_dtype)
+        idx = 0
+
+        for i in range(ny - 1):
+            k = i + 1
+            ret[idx] = innerprod(x[:k], y_flipped[-k:])
+            idx = idx + 1
+
+        for i in range(nx - ny + 1):
+            ret[idx] = innerprod(x[i : i + ny], y_flipped)
+            idx = idx + 1
+
+        for i in range(ny - 1):
+            k = ny - i - 1
+            ret[idx] = innerprod(x[-k:], y_flipped[:k])
+            idx = idx + 1
+
+        return ret
+
+    @numba_njit
+    def convolve_1d(x, y, mode):
+        if mode:
+            return full_convolve1d(x, y)
+        else:
+            return valid_convolve1d(x, y)
+
+    return convolve_1d
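
With mode now a runtime scalar, the numba backend compiles both kernels and branches on the boolean inside njit code instead of choosing one kernel at dispatch time. The sliding inner-product formulation can be cross-checked against NumPy without numba; a hypothetical reference mirroring valid_convolve1d above:

    import numpy as np

    def valid_convolve1d_ref(x, y):
        # Slide the flipped kernel over the longer input, one dot product per step
        nx, ny = len(x), len(y)
        if nx < ny:
            x, y = y, x
            nx, ny = ny, nx
        y_flipped = y[::-1]
        return np.array([x[i : i + ny] @ y_flipped for i in range(nx - ny + 1)])

    rng = np.random.default_rng(0)
    x, y = rng.normal(size=10), rng.normal(size=4)
    np.testing.assert_allclose(
        valid_convolve1d_ref(x, y), np.convolve(x, y, mode="valid")
    )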

pytensor/tensor/blockwise.py

Lines changed: 2 additions & 2 deletions
@@ -360,12 +360,12 @@ def extract_core_shape_from_infer_shape():
             dummy_fgraph, dummy_core_node, core_input_shapes
         )
 
-        # Set to None those core_shapes that depend on dummy_core_inputs,
-        # meaning their value may not be constant across batch dims of the Blockwise
         if not dummy_core_inputs:
             # All inputs are unbatched, so the core_shape can be used as is
             return core_output_shapes
         else:
+            # Set to None those core_shapes that depend on dummy_core_inputs,
+            # meaning their value may not be constant across batch dims of the Blockwise
             set_dummy_core_inputs = set(dummy_core_inputs)
             safe_core_output_shapes = [list(shape) for shape in core_output_shapes]
             for core_out_shape in safe_core_output_shapes:

pytensor/tensor/rewriting/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@
 import pytensor.tensor.rewriting.blas_c
 import pytensor.tensor.rewriting.blas_scipy
 import pytensor.tensor.rewriting.blockwise
-import pytensor.tensor.rewriting.conv
 import pytensor.tensor.rewriting.einsum
 import pytensor.tensor.rewriting.elemwise
 import pytensor.tensor.rewriting.extra_ops

pytensor/tensor/rewriting/conv.py

Lines changed: 0 additions & 78 deletions
This file was deleted.
