11from typing import TYPE_CHECKING , Literal , cast
22
3+ import numpy as np
34from numpy import convolve as numpy_convolve
45
5- from pytensor .graph import Apply
6+ from pytensor .gradient import DisconnectedType
7+ from pytensor .graph import Apply , Constant
68from pytensor .link .c .op import COp
9+ from pytensor .scalar import as_scalar
710from pytensor .scalar .basic import upcast
811from pytensor .tensor .basic import as_tensor_variable , join , zeros
912from pytensor .tensor .blockwise import Blockwise
10- from pytensor .tensor .math import maximum , minimum
13+ from pytensor .tensor .math import maximum , minimum , switch
1114from pytensor .tensor .type import vector
1215from pytensor .tensor .variable import TensorVariable
1316
1720
1821
1922class Convolve1d (COp ):
20- __props__ = ("mode" , )
21- gufunc_signature = "(n),(k)->(o)"
23+ __props__ = ()
24+ gufunc_signature = "(n),(k),() ->(o)"
2225
23- def __init__ (self , mode : Literal ["full" , "valid" ] = "full" ):
24- if mode not in ("full" , "valid" ):
25- raise ValueError (f"Invalid mode: { mode } " )
26- self .mode = mode
27-
28- def make_node (self , in1 , in2 ):
26+ def make_node (self , in1 , in2 , full_mode ):
2927 in1 = as_tensor_variable (in1 )
3028 in2 = as_tensor_variable (in2 )
29+ full_mode = as_scalar (full_mode )
3130
32- assert in1 .ndim == 1
33- assert in2 .ndim == 1
31+ if not (in1 .ndim == 1 and in2 .ndim == 1 ):
32+ raise ValueError ("Convolution inputs must be vector (ndim=1)" )
33+ if not full_mode .dtype == "bool" :
34+ raise ValueError ("Convolution mode must be a boolean type" )
3435
3536 dtype = upcast (in1 .dtype , in2 .dtype )
36-
3737 n = in1 .type .shape [0 ]
3838 k = in2 .type .shape [0 ]
39+ match full_mode :
40+ case Constant ():
41+ static_mode = "full" if full_mode .data else "valid"
42+ case _:
43+ static_mode = None
3944
40- if n is None or k is None :
45+ if n is None or k is None or static_mode is None :
4146 out_shape = (None ,)
42- elif self . mode == "full" :
47+ elif static_mode == "full" :
4348 out_shape = (n + k - 1 ,)
4449 else : # mode == "valid":
4550 out_shape = (max (n , k ) - min (n , k ) + 1 ,)
4651
4752 out = vector (dtype = dtype , shape = out_shape )
48- return Apply (self , [in1 , in2 ], [out ])
53+ return Apply (self , [in1 , in2 , full_mode ], [out ])
4954
5055 def perform (self , node , inputs , outputs ):
5156 # We use numpy_convolve as that's what scipy would use if method="direct" was passed.
5257 # And mode != "same", which this Op doesn't cover anyway.
53- outputs [0 ][0 ] = numpy_convolve (* inputs , mode = self .mode )
58+ in1 , in2 , full_mode = inputs
59+ outputs [0 ][0 ] = numpy_convolve (in1 , in2 , mode = "full" if full_mode else "valid" )
5460
5561 def infer_shape (self , fgraph , node , shapes ):
56- in1_shape , in2_shape = shapes
62+ _ , _ , full_mode = node .inputs
63+ in1_shape , in2_shape , _ = shapes
5764 n = in1_shape [0 ]
5865 k = in2_shape [0 ]
59- if self .mode == "full" :
60- shape = n + k - 1
61- else : # mode == "valid":
62- shape = maximum (n , k ) - minimum (n , k ) + 1
66+ shape_valid = maximum (n , k ) - minimum (n , k ) + 1
67+ shape_full = n + k - 1
68+ shape = switch (full_mode , shape_full , shape_valid )
6369 return [[shape ]]
6470
71+ def connection_pattern (self , node ):
72+ return [[True ], [True ], [False ]]
73+
6574 def L_op (self , inputs , outputs , output_grads ):
66- in1 , in2 = inputs
75+ in1 , in2 , full_mode = inputs
6776 [grad ] = output_grads
6877
69- if self .mode == "full" :
70- valid_conv = type (self )(mode = "valid" )
71- in1_bar = valid_conv (grad , in2 [::- 1 ])
72- in2_bar = valid_conv (grad , in1 [::- 1 ])
78+ n = in1 .shape [0 ]
79+ k = in2 .shape [0 ]
7380
74- else : # mode == "valid":
75- full_conv = type (self )(mode = "full" )
76- n = in1 .shape [0 ]
77- k = in2 .shape [0 ]
78- kmn = maximum (0 , k - n )
79- nmk = maximum (0 , n - k )
80- # We need mode="full" if k >= n else "valid" for `in1_bar` (opposite for `in2_bar`), but mode is not symbolic.
81- # Instead, we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter.
82- # There is a rewrite that optimizes this case when n, k are static
83- in1_bar = full_conv (grad , in2 [::- 1 ])
84- in1_bar = in1_bar [kmn : in1_bar .shape [0 ] - kmn ]
85- in2_bar = full_conv (grad , in1 [::- 1 ])
86- in2_bar = in2_bar [nmk : in2_bar .shape [0 ] - nmk ]
87-
88- return [in1_bar , in2_bar ]
81+ # If mode is "full", or mode is "valid" and k >= n, then in1_bar mode should use "valid" convolve
82+ # The expression below is equivalent to ~(full_mode | (k >= n))
83+ full_mode_in1_bar = ~ full_mode & (k < n )
84+ # If mode is "full", or mode is "valid" and n >= k, then in2_bar mode should use "valid" convolve
85+ # The expression below is equivalent to ~(full_mode | (n >= k))
86+ full_mode_in2_bar = ~ full_mode & (n < k )
87+
88+ return [
89+ self (grad , in2 [::- 1 ], full_mode_in1_bar ),
90+ self (grad , in1 [::- 1 ], full_mode_in2_bar ),
91+ DisconnectedType ()(),
92+ ]
8993
9094 def c_code_cache_version (self ):
91- return ( 1 ,)
95+ return None # (2 ,)
9296
9397 def c_code (self , node , name , inputs , outputs , sub ):
94- # raise NotImplementedError()
95- in1 , in2 = inputs
98+ in1 , in2 , full_mode = inputs
9699 [out ] = outputs
97- mode_str = self .mode
98-
99- if mode_str == "full" :
100- np_mode_val = 2 # NPY_CONVOLVE_FULL
101- elif mode_str == "valid" :
102- np_mode_val = 0 # NPY_CONVOLVE_VALID
103- else :
104- # This case should ideally be prevented by __init__ or make_node
105- raise ValueError (f"Unsupported mode { mode_str } " )
106100
107101 code = f"""
108102 {{
@@ -158,7 +152,7 @@ def c_code(self, node, name, inputs, outputs, sub):
158152
159153 // TODO: Use lower level implementation that allows reusing the output buffer
160154 Py_XDECREF({ out } );
161- { out } = (PyArrayObject*) PyArray_Correlate2((PyObject*){ in1 } , (PyObject*)in2_flipped_view, { np_mode_val } );
155+ { out } = (PyArrayObject*) PyArray_Correlate2((PyObject*){ in1 } , (PyObject*)in2_flipped_view, { full_mode } ? 2 : 0 );
162156 Py_XDECREF(in2_flipped_view); // Clean up the view if correlate fails
163157 if (!{ out } ) {{
164158 // PyArray_Correlate already set an error
@@ -169,6 +163,9 @@ def c_code(self, node, name, inputs, outputs, sub):
169163 return code
170164
171165
166+ blockwise_convolve_1d = Blockwise (Convolve1d ())
167+
168+
172169def convolve1d (
173170 in1 : "TensorLike" ,
174171 in2 : "TensorLike" ,
@@ -212,4 +209,5 @@ def convolve1d(
212209 )
213210 mode = "valid"
214211
215- return cast (TensorVariable , Blockwise (Convolve1d (mode = mode ))(in1 , in2 ))
212+ full_mode = as_scalar (np .bool_ (mode == "full" ))
213+ return cast (TensorVariable , blockwise_convolve_1d (in1 , in2 , full_mode ))
0 commit comments