Skip to content

Commit 880ef1d

Browse files
conv2d gradient v0
1 parent 5c49c48 commit 880ef1d

File tree

2 files changed

+57
-8
lines changed

2 files changed

+57
-8
lines changed

pytensor/tensor/signal/conv.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from pytensor.gradient import DisconnectedType
88
from pytensor.graph import Apply, Constant
9+
from pytensor.graph.op import Op
910
from pytensor.link.c.op import COp
1011
from pytensor.scalar import as_scalar
1112
from pytensor.scalar.basic import upcast
@@ -220,18 +221,16 @@ class Convolve2D(Op):
220221

221222
def __init__(
222223
self,
223-
mode: Literal["full", "valid", "same"] = "full",
224+
mode: Literal["full", "valid"] = "full",
224225
boundary: Literal["fill", "wrap", "symm"] = "fill",
225226
fillvalue: float | int = 0,
226227
):
227-
if mode not in ("full", "valid", "same"):
228+
if mode not in ("full", "valid"):
228229
raise ValueError(f"Invalid mode: {mode}")
229-
if boundary not in ("fill", "wrap", "symm"):
230-
raise ValueError(f"Invalid boundary: {boundary}")
231230

232231
self.mode = mode
233-
self.boundary = boundary
234232
self.fillvalue = fillvalue
233+
self.boundary = boundary
235234

236235
def make_node(self, in1, in2):
237236
in1, in2 = map(as_tensor_variable, (in1, in2))
@@ -262,8 +261,13 @@ def make_node(self, in1, in2):
262261

263262
def perform(self, node, inputs, outputs):
264263
in1, in2 = inputs
264+
265+
# if all(inpt.dtype.kind in ['f', 'c'] for inpt in inputs):
266+
# outputs[0][0] = scipy_convolve(in1, in2, mode=self.mode, method='fft')
267+
#
268+
# else:
265269
outputs[0][0] = scipy_convolve2d(
266-
in1, in2, mode=self.mode, boundary=self.boundary, fillvalue=self.fillvalue
270+
in1, in2, mode=self.mode, fillvalue=self.fillvalue, boundary=self.boundary
267271
)
268272

269273
def infer_shape(self, fgraph, node, shapes):
@@ -284,7 +288,18 @@ def infer_shape(self, fgraph, node, shapes):
284288
return [shape]
285289

286290
def L_op(self, inputs, outputs, output_grads):
287-
raise NotImplementedError
291+
in1, in2 = inputs
292+
incoming_grads = output_grads[0]
293+
294+
if self.mode == "full":
295+
prop_dict = self._props_dict()
296+
prop_dict["mode"] = "valid"
297+
conv_valid = type(self)(**prop_dict)
298+
299+
in1_grad = conv_valid(in2, incoming_grads)
300+
in2_grad = conv_valid(in1, incoming_grads)
301+
302+
return [in1_grad, in2_grad]
288303

289304

290305
def convolve2d(
@@ -325,6 +340,9 @@ def convolve2d(
325340
in1 = as_tensor_variable(in1)
326341
in2 = as_tensor_variable(in2)
327342

343+
# TODO: Handle boundaries symbolically
344+
# TODO: Handle 'same' symbolically
345+
328346
blockwise_convolve = Blockwise(
329347
Convolve2D(mode=mode, boundary=boundary, fillvalue=fillvalue)
330348
)

tests/tensor/signal/test_conv.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,5 +163,36 @@ def test_convolve2d(kernel_shape, data_shape, mode, boundary):
163163
scipy_convolve2d(
164164
data_val, kernel_val, mode=mode, boundary=boundary, fillvalue=0
165165
),
166-
rtol=1e-6 if config.floatX == "float32" else 1e-15,
166+
atol=1e-6 if config.floatX == "float32" else 1e-8,
167167
)
168+
169+
utt.verify_grad(lambda k: op(data_val, k).sum(), [kernel_val], eps=1e-4)
170+
171+
172+
# @pytest.mark.parametrize(
173+
# "data_shape, kernel_shape", [[(10, 1, 8, 8), (3, 1, 3, 3)], # 8x8 grayscale
174+
# [(1000, 1, 8, 8), (3, 1, 1, 3)], # same, but with 1000 images
175+
# [(10, 3, 64, 64), (10, 3, 8, 8)], # 64x64 RGB
176+
# [(1000, 3, 64, 64), (10, 3, 8, 8)], # same, but with 1000 images
177+
# [(3, 100, 100, 100), (250, 100, 50, 50)]], # Very large, deep hidden layer or something
178+
#
179+
# ids=lambda x: f"data_shape={x[0]}, kernel_shape={x[1]}"
180+
# )
181+
# @pytest.mark.parametrize('func', ['new', 'theano'], ids=['new-impl', 'theano-impl'])
182+
# def test_conv2d_nn_benchmark(data_shape, kernel_shape, func, benchmark):
183+
# import pytensor.tensor as pt
184+
# x = pt.tensor("x", shape=data_shape)
185+
# y = pt.tensor("y", shape=kernel_shape)
186+
#
187+
# if func == 'new':
188+
# out = nn_conv2d(x, y)
189+
# else:
190+
# out = conv2d(input=x, filters=y, border_mode="valid")
191+
#
192+
# rng = np.random.default_rng(38)
193+
# x_test = rng.normal(size=data_shape).astype(x.dtype)
194+
# y_test = rng.normal(size=kernel_shape).astype(y.dtype)
195+
#
196+
# fn = function([x, y], out, trust_input=True)
197+
#
198+
# benchmark(fn, x_test, y_test)

0 commit comments

Comments
 (0)