Skip to content

Commit 5055232

Browse files
committed
Add missing file
1 parent 113c47b commit 5055232

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

ext/NNlibEnzymeExt/NNlibEnzymeExt.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
module NNlibEnzymeExt
2+
3+
using NNlib
4+
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)
5+
6+
using EnzymeCore
7+
8+
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT}
9+
10+
@assert !(OutType <: Const)
11+
if OutType <: Duplicated || OutType <: DuplicatedNoNeed
12+
func.val(y.val, x.val, w.val, cdims.val; kwargs...)
13+
end
14+
15+
primal = if EnzymeCore.EnzymeRules.needs_primal(config)
16+
y.val
17+
else
18+
nothing
19+
end
20+
shadow = if EnzymeCore.EnzymeRules.needs_shadow(config)
21+
y.dval
22+
else
23+
nothing
24+
end
25+
26+
# Cache x if its overwritten and w is active (and thus required)
27+
cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: Const) ) ? copy(x.val) : nothing
28+
29+
# Cache w if its overwritten and x is active (and thus required)
30+
cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: Const) ) ? copy(w.val) : nothing
31+
32+
cache = (cache_x, cache_w)
33+
34+
return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache)
35+
end
36+
37+
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT}
38+
cache_x, cache_w = cache
39+
40+
# Don't cache x if not overwritten and w is active (and thus required)
41+
if !(typeof(w) <: Const)
42+
if !EnzymeCore.EnzymeRules.overwritten(config)[3]
43+
cache_x = x.val
44+
end
45+
end
46+
47+
# Don't cache w if not overwritten and x is active (and thus required)
48+
if !(typeof(x) <: Const)
49+
if !EnzymeCore.EnzymeRules.overwritten(config)[4]
50+
cache_w = w.val
51+
end
52+
end
53+
54+
dys = y.dval
55+
dxs = (typeof(x) <: Const) ? dys : x.dval
56+
dws = (typeof(w) <: Const) ? dys : w.dval
57+
58+
if EnzymeCore.EnzymeRules.width(config) == 1
59+
dys = (dys,)
60+
dxs = (dxs,)
61+
dws = (dws,)
62+
end
63+
64+
for (dy, dx, dw) in zip(dys, dxs, dws)
65+
if !(typeof(x) <: Const) && dx !== x
66+
# dx += grad wrt x
67+
NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
68+
end
69+
if !(typeof(w) <: Const) && dw !== w
70+
# dw += grad wrt w
71+
NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
72+
end
73+
dy .= 0
74+
end
75+
76+
return (nothing, nothing, nothing, nothing)
77+
end
78+
79+
end

0 commit comments

Comments
 (0)