Skip to content

Commit 7f6af69

Browse files
committed
Add EnzymeRule for conv
1 parent 607de4b commit 7f6af69

File tree

2 files changed

+96
-0
lines changed

2 files changed

+96
-0
lines changed

ext/NNlibEnzymeExt.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
module NNlibEnzymeExt
2+
3+
using NNlib
4+
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)
5+
6+
using Enzyme
7+
8+
using EnzymeCore
9+
10+
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims::Const; kwargs...)
11+
) where {OutType, RT}
12+
13+
@assert !(OutType <: Const)
14+
if OutType <: Duplicated || OutType <: DuplicatedNoNeed
15+
func.val(y.val, x.val, y.val, cdims.val; kwargs...)
16+
end
17+
18+
dres = if EnzymeRules.width(config) == 1
19+
func.val(prob.dval, alg.val; kwargs...)
20+
else
21+
ntuple(Val(EnzymeRules.width(config))) do i
22+
Base.@_inline_meta
23+
func.val(prob.dval[i], alg.val; kwargs...)
24+
end
25+
end
26+
27+
primal = if EnzymeRules.needs_primal(config)
28+
y.val
29+
else
30+
nothing
31+
end
32+
shadow = if EnzymeRules.needs_shadow(config)
33+
y.dval
34+
else
35+
nothing
36+
end
37+
38+
# Cache x if its overwritten and w is active (and thus required)
39+
cache_x = ( EnzymeRules.overwritten(config)[3] && !(typeof(w) <: Const) ) ? copy(x.val) : nothing
40+
41+
# Cache w if its overwritten and x is active (and thus required)
42+
cache_w = ( EnzymeRules.overwritten(config)[4] && !(typeof(x) <: Const) ) ? copy(w.val) : nothing
43+
44+
cache = (cache_x, cache_w)
45+
46+
return EnzymeCore.EnzymeRules.AugmentedReturn(y.val, y.dval, cache)
47+
end
48+
49+
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims::Const; kwargs...) where {RT}
50+
cache_x, cache_w = cache
51+
52+
# Don't cache x if not overwritten and w is active (and thus required)
53+
if !(typeof(w) <: Const)
54+
if !EnzymeRules.overwritten(config)[3]
55+
cache_x = x.val
56+
end
57+
end
58+
59+
# Don't cache w if not overwritten and x is active (and thus required)
60+
if !(typeof(x) <: Const)
61+
if !EnzymeRules.overwritten(config)[4]
62+
cache_w = w.val
63+
end
64+
end
65+
66+
dys = y.dval
67+
dxs = (typeof(x) <: Const) ? nothing : x.dval
68+
dws = (typeof(w) <: Const) ? nothing : w.dval
69+
70+
if EnzymeRules.width(config) == 1
71+
dys = (dys,)
72+
dxs = (dxs,)
73+
dws = (dws,)
74+
end
75+
76+
for (dy, dx, dw) in (dys, dxs, dws)
77+
if !(typeof(x) <: Const)
78+
# dx += grad wrt x
79+
NNlib.∇conv_data!(dx, dy, cache_w, cdims; alpha=1, beta=1, kwargs...)
80+
end
81+
if !(typeof(y) <: Const)
82+
# dw += grad wrt w
83+
NNlib.∇conv_filter!(dw, cache_x, dy, cdims; alpha=1, beta=1, kwargs...)
84+
end
85+
end
86+
87+
return (nothing, nothing, nothing, nothing)
88+
end

src/NNlib.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,4 +123,12 @@ include("impl/depthwiseconv_im2col.jl")
123123
include("impl/pooling_direct.jl")
124124
include("deprecations.jl")
125125

126+
function __init__()
127+
@static if !isdefined(Base, :get_extension)
128+
@require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin
129+
include("../ext/NNlibEnzymeExt.jl")
130+
end
131+
end
132+
end
133+
126134
end # module NNlib

0 commit comments

Comments
 (0)