Skip to content

Commit 5fb8d27

Browse files
authored
Merge pull request #387 from vincentmolin/rruleconv
add `rrule(::typeof(∇conv_filter)`
2 parents c0b4b8b + ae4866e commit 5fb8d27

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

src/conv.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,19 @@ for conv in [:conv, :depthwiseconv]
329329
end
330330
end
331331

332+
function rrule(::typeof(∇conv_filter), x, dy, cdims; kw...)
333+
function ∇conv_filter_pullback(Δ)
334+
Δ1 = colmajor(unthunk(Δ))
335+
return (
336+
NoTangent(),
337+
@thunk(∇conv_data(dy, Δ1, cdims, kw...)),
338+
@thunk(conv(x, Δ1, cdims, kw...)),
339+
NoTangent(),
340+
)
341+
end
342+
return ∇conv_filter(x, dy, cdims; kw...), ∇conv_filter_pullback
343+
end
344+
332345
# Use NNPACK if it is available and the operation is supported
333346
# commented out 'till proper benchmarking and more correctness test are performed
334347
# if is_nnpack_available()

test/conv.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,10 @@ end
734734
# else
735735
gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
736736
# end
737+
gradtest((x, y) -> ∇conv_filter(x, y, cdims), x, y)
738+
if spatial_rank < 3
739+
gradtest((x, y) -> sum(∇conv_filter(x, y, cdims)), x, y)
740+
end
737741

738742
dcdims = DepthwiseConvDims(x, w)
739743
gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w)

0 commit comments

Comments
 (0)