Skip to content

Commit 6dd08a4

Browse files
committed
Add missing file
1 parent 9b2b6c2 commit 6dd08a4

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

src/enzyme.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import EnzymeCore
2+
3+
function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT}
4+
5+
@assert !(OutType <: EnzymeCore.Const)
6+
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed
7+
func.val(y.val, x.val, w.val, cdims.val; kwargs...)
8+
end
9+
10+
primal = if EnzymeCore.EnzymeRules.needs_primal(config)
11+
y.val
12+
else
13+
nothing
14+
end
15+
shadow = if EnzymeCore.EnzymeRules.needs_shadow(config)
16+
y.dval
17+
else
18+
nothing
19+
end
20+
21+
# Cache x if its overwritten and w is active (and thus required)
22+
cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: EnzymeCore.Const) ) ? copy(x.val) : nothing
23+
24+
# Cache w if its overwritten and x is active (and thus required)
25+
cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: EnzymeCore.Const) ) ? copy(w.val) : nothing
26+
27+
cache = (cache_x, cache_w)
28+
29+
return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache)
30+
end
31+
32+
function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT}
33+
cache_x, cache_w = cache
34+
35+
# Don't cache x if not overwritten and w is active (and thus required)
36+
if !(typeof(w) <: EnzymeCore.Const)
37+
if !EnzymeCore.EnzymeRules.overwritten(config)[3]
38+
cache_x = x.val
39+
end
40+
end
41+
42+
# Don't cache w if not overwritten and x is active (and thus required)
43+
if !(typeof(x) <: EnzymeCore.Const)
44+
if !EnzymeCore.EnzymeRules.overwritten(config)[4]
45+
cache_w = w.val
46+
end
47+
end
48+
49+
dys = y.dval
50+
dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval
51+
dws = (typeof(w) <: EnzymeCore.Const) ? dys : w.dval
52+
53+
if EnzymeCore.EnzymeRules.width(config) == 1
54+
dys = (dys,)
55+
dxs = (dxs,)
56+
dws = (dws,)
57+
end
58+
59+
for (dy, dx, dw) in zip(dys, dxs, dws)
60+
if !(typeof(x) <: EnzymeCore.Const) && dx !== x
61+
# dx += grad wrt x
62+
NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
63+
end
64+
if !(typeof(w) <: EnzymeCore.Const) && dw !== w
65+
# dw += grad wrt w
66+
NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
67+
end
68+
dy .= 0
69+
end
70+
71+
return (nothing, nothing, nothing, nothing)
72+
end

0 commit comments

Comments
 (0)