6
6
7
7
from pytensor .gradient import DisconnectedType
8
8
from pytensor .graph import Apply , Constant
9
+ from pytensor .graph .op import Op
9
10
from pytensor .link .c .op import COp
10
11
from pytensor .scalar import as_scalar
11
12
from pytensor .scalar .basic import upcast
@@ -220,18 +221,16 @@ class Convolve2D(Op):
220
221
221
222
def __init__ (
222
223
self ,
223
- mode : Literal ["full" , "valid" , "same" ] = "full" ,
224
+ mode : Literal ["full" , "valid" ] = "full" ,
224
225
boundary : Literal ["fill" , "wrap" , "symm" ] = "fill" ,
225
226
fillvalue : float | int = 0 ,
226
227
):
227
- if mode not in ("full" , "valid" , "same" ):
228
+ if mode not in ("full" , "valid" ):
228
229
raise ValueError (f"Invalid mode: { mode } " )
229
- if boundary not in ("fill" , "wrap" , "symm" ):
230
- raise ValueError (f"Invalid boundary: { boundary } " )
231
230
232
231
self .mode = mode
233
- self .boundary = boundary
234
232
self .fillvalue = fillvalue
233
+ self .boundary = boundary
235
234
236
235
def make_node (self , in1 , in2 ):
237
236
in1 , in2 = map (as_tensor_variable , (in1 , in2 ))
@@ -262,8 +261,13 @@ def make_node(self, in1, in2):
262
261
263
262
def perform (self , node , inputs , outputs ):
264
263
in1 , in2 = inputs
264
+
265
+ # if all(inpt.dtype.kind in ['f', 'c'] for inpt in inputs):
266
+ # outputs[0][0] = scipy_convolve(in1, in2, mode=self.mode, method='fft')
267
+ #
268
+ # else:
265
269
outputs [0 ][0 ] = scipy_convolve2d (
266
- in1 , in2 , mode = self .mode , boundary = self .boundary , fillvalue = self .fillvalue
270
+ in1 , in2 , mode = self .mode , fillvalue = self .fillvalue , boundary = self .boundary
267
271
)
268
272
269
273
def infer_shape (self , fgraph , node , shapes ):
@@ -284,7 +288,18 @@ def infer_shape(self, fgraph, node, shapes):
284
288
return [shape ]
285
289
286
290
def L_op (self , inputs , outputs , output_grads ):
287
- raise NotImplementedError
291
+ in1 , in2 = inputs
292
+ incoming_grads = output_grads [0 ]
293
+
294
+ if self .mode == "full" :
295
+ prop_dict = self ._props_dict ()
296
+ prop_dict ["mode" ] = "valid"
297
+ conv_valid = type (self )(** prop_dict )
298
+
299
+ in1_grad = conv_valid (in2 , incoming_grads )
300
+ in2_grad = conv_valid (in1 , incoming_grads )
301
+
302
+ return [in1_grad , in2_grad ]
288
303
289
304
290
305
def convolve2d (
@@ -325,6 +340,9 @@ def convolve2d(
325
340
in1 = as_tensor_variable (in1 )
326
341
in2 = as_tensor_variable (in2 )
327
342
343
+ # TODO: Handle boundaries symbolically
344
+ # TODO: Handle 'same' symbolically
345
+
328
346
blockwise_convolve = Blockwise (
329
347
Convolve2D (mode = mode , boundary = boundary , fillvalue = fillvalue )
330
348
)
0 commit comments