11from typing import TYPE_CHECKING , Literal , cast
22
3+ import numpy as np
34from numpy import convolve as numpy_convolve
45
5- from pytensor .graph import Apply
6+ from pytensor .gradient import DisconnectedType
7+ from pytensor .graph import Apply , Constant
68from pytensor .link .c .op import COp
9+ from pytensor .scalar import as_scalar
710from pytensor .scalar .basic import upcast
811from pytensor .tensor .basic import as_tensor_variable , join , zeros
912from pytensor .tensor .blockwise import Blockwise
10- from pytensor .tensor .math import maximum , minimum
13+ from pytensor .tensor .math import maximum , minimum , switch
1114from pytensor .tensor .type import vector
1215from pytensor .tensor .variable import TensorVariable
1316
1720
1821
1922class Convolve1d (COp ):
20- __props__ = ("mode" , )
21- gufunc_signature = "(n),(k)->(o)"
23+ __props__ = ()
24+ gufunc_signature = "(n),(k),() ->(o)"
2225
23- def __init__ (self , mode : Literal ["full" , "valid" ] = "full" ):
24- if mode not in ("full" , "valid" ):
25- raise ValueError (f"Invalid mode: { mode } " )
26- self .mode = mode
27-
28- def make_node (self , in1 , in2 ):
26+ def make_node (self , in1 , in2 , is_full_mode ):
2927 in1 = as_tensor_variable (in1 )
3028 in2 = as_tensor_variable (in2 )
29+ is_full_mode = as_scalar (is_full_mode )
3130
32- assert in1 .ndim == 1
33- assert in2 .ndim == 1
31+ if not (in1 .ndim == 1 and in2 .ndim == 1 ):
32+ raise ValueError ("Convolution inputs must be vector (ndim=1)" )
33+ if not is_full_mode .dtype == "bool" :
34+ raise ValueError ("Convolution mode must be a boolean type" )
3435
3536 dtype = upcast (in1 .dtype , in2 .dtype )
36-
3737 n = in1 .type .shape [0 ]
3838 k = in2 .type .shape [0 ]
39+ match is_full_mode :
40+ case Constant ():
41+ static_mode = "full" if is_full_mode .data else "valid"
42+ case _:
43+ static_mode = None
3944
40- if n is None or k is None :
45+ if n is None or k is None or static_mode is None :
4146 out_shape = (None ,)
42- elif self . mode == "full" :
47+ elif static_mode == "full" :
4348 out_shape = (n + k - 1 ,)
4449 else : # mode == "valid":
4550 out_shape = (max (n , k ) - min (n , k ) + 1 ,)
4651
4752 out = vector (dtype = dtype , shape = out_shape )
48- return Apply (self , [in1 , in2 ], [out ])
53+ return Apply (self , [in1 , in2 , is_full_mode ], [out ])
4954
5055 def perform (self , node , inputs , outputs ):
5156 # We use numpy_convolve as that's what scipy would use if method="direct" was passed.
5257 # And mode != "same", which this Op doesn't cover anyway.
53- outputs [0 ][0 ] = numpy_convolve (* inputs , mode = self .mode )
58+ in1 , in2 , is_full_mode = inputs
59+ outputs [0 ][0 ] = numpy_convolve (
60+ in1 , in2 , mode = "full" if is_full_mode else "valid"
61+ )
5462
5563 def infer_shape (self , fgraph , node , shapes ):
56- in1_shape , in2_shape = shapes
64+ _ , _ , is_full_mode = node .inputs
65+ in1_shape , in2_shape , _ = shapes
5766 n = in1_shape [0 ]
5867 k = in2_shape [0 ]
59- if self .mode == "full" :
60- shape = n + k - 1
61- else : # mode == "valid":
62- shape = maximum (n , k ) - minimum (n , k ) + 1
68+ shape_valid = maximum (n , k ) - minimum (n , k ) + 1
69+ shape_full = n + k - 1
70+ shape = switch (is_full_mode , shape_full , shape_valid )
6371 return [[shape ]]
6472
73+ def connection_pattern (self , node ):
74+ return [[True ], [True ], [False ]]
75+
6576 def L_op (self , inputs , outputs , output_grads ):
66- in1 , in2 = inputs
77+ in1 , in2 , is_full_mode = inputs
6778 [grad ] = output_grads
6879
69- if self .mode == "full" :
70- valid_conv = type (self )(mode = "valid" )
71- in1_bar = valid_conv (grad , in2 [::- 1 ])
72- in2_bar = valid_conv (grad , in1 [::- 1 ])
80+ n = in1 .shape [0 ]
81+ k = in2 .shape [0 ]
7382
74- else : # mode == "valid":
75- full_conv = type (self )(mode = "full" )
76- n = in1 .shape [0 ]
77- k = in2 .shape [0 ]
78- kmn = maximum (0 , k - n )
79- nmk = maximum (0 , n - k )
80- # We need mode="full" if k >= n else "valid" for `in1_bar` (opposite for `in2_bar`), but mode is not symbolic.
81- # Instead, we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter.
82- # There is a rewrite that optimizes this case when n, k are static
83- in1_bar = full_conv (grad , in2 [::- 1 ])
84- in1_bar = in1_bar [kmn : in1_bar .shape [0 ] - kmn ]
85- in2_bar = full_conv (grad , in1 [::- 1 ])
86- in2_bar = in2_bar [nmk : in2_bar .shape [0 ] - nmk ]
87-
88- return [in1_bar , in2_bar ]
83+ # If mode is "full", or mode is "valid" and k >= n, then in1_bar mode should use "valid" convolve
84+ # The expression below is equivalent to switch(mode | (k >= n), False, True)
85+ mode_in1_bar = ~ is_full_mode & (n > k )
86+ # If mode is "full", or mode is "valid" and n >= k, then in2_bar mode should use "valid" convolve
87+ # The expression below is equivalent to switch(mode | (n >= k), False, True)
88+ mode_in2_bar = ~ is_full_mode & (k > n )
89+
90+ return [
91+ self (grad , in2 [::- 1 ], mode_in1_bar ),
92+ self (grad , in1 [::- 1 ], mode_in2_bar ),
93+ DisconnectedType ()(),
94+ ]
8995
9096 def c_code_cache_version (self ):
91- return ( 1 ,)
97+ return None # (2 ,)
9298
9399 def c_code (self , node , name , inputs , outputs , sub ):
94- # raise NotImplementedError()
95- in1 , in2 = inputs
100+ in1 , in2 , is_full_mode = inputs
96101 [out ] = outputs
97- mode_str = self .mode
98-
99- if mode_str == "full" :
100- np_mode_val = 2 # NPY_CONVOLVE_FULL
101- elif mode_str == "valid" :
102- np_mode_val = 0 # NPY_CONVOLVE_VALID
103- else :
104- # This case should ideally be prevented by __init__ or make_node
105- raise ValueError (f"Unsupported mode { mode_str } " )
106102
107103 code = f"""
108104 {{
@@ -158,7 +154,7 @@ def c_code(self, node, name, inputs, outputs, sub):
158154
159155 // TODO: Use lower level implementation that allows reusing the output buffer
160156 Py_XDECREF({ out } );
161- { out } = (PyArrayObject*) PyArray_Correlate2((PyObject*){ in1 } , (PyObject*)in2_flipped_view, { np_mode_val } );
157+ { out } = (PyArrayObject*) PyArray_Correlate2((PyObject*){ in1 } , (PyObject*)in2_flipped_view, { is_full_mode } ? 2 : 0 );
162158 Py_XDECREF(in2_flipped_view); // Clean up the view if correlate fails
163159 if (!{ out } ) {{
164160 // PyArray_Correlate already set an error
@@ -169,6 +165,9 @@ def c_code(self, node, name, inputs, outputs, sub):
169165 return code
170166
171167
# Shared module-level Blockwise wrapper used by `convolve1d` below, so every
# call reuses one Op instance (Convolve1d has empty __props__, hence all
# instances compare equal and merge in the graph).
blockwise_convolve_1d = Blockwise(Convolve1d())
172171def convolve1d (
173172 in1 : "TensorLike" ,
174173 in2 : "TensorLike" ,
@@ -212,4 +211,5 @@ def convolve1d(
212211 )
213212 mode = "valid"
214213
215- return cast (TensorVariable , Blockwise (Convolve1d (mode = mode ))(in1 , in2 ))
214+ is_full_mode = as_scalar (np .bool_ (mode == "full" ))
215+ return cast (TensorVariable , blockwise_convolve_1d (in1 , in2 , is_full_mode ))
0 commit comments