1
+ module NNlibEnzymeExt
2
+
3
+ using NNlib
4
+ isdefined (Base, :get_extension ) ? (import Enzyme) : (import .. Enzyme)
5
+
6
+ using EnzymeCore
7
+
8
+ function EnzymeCore. EnzymeRules. augmented_primal (config, func:: Const{typeof(NNlib.conv!)} , :: Type{RT} , y:: OutType , x, w, cdims; kwargs... ) where {OutType, RT}
9
+
10
+ @assert ! (OutType <: Const )
11
+ if OutType <: Duplicated || OutType <: DuplicatedNoNeed
12
+ func. val (y. val, x. val, w. val, cdims. val; kwargs... )
13
+ end
14
+
15
+ primal = if EnzymeCore. EnzymeRules. needs_primal (config)
16
+ y. val
17
+ else
18
+ nothing
19
+ end
20
+ shadow = if EnzymeCore. EnzymeRules. needs_shadow (config)
21
+ y. dval
22
+ else
23
+ nothing
24
+ end
25
+
26
+ # Cache x if its overwritten and w is active (and thus required)
27
+ cache_x = ( EnzymeCore. EnzymeRules. overwritten (config)[3 ] && ! (typeof (w) <: Const ) ) ? copy (x. val) : nothing
28
+
29
+ # Cache w if its overwritten and x is active (and thus required)
30
+ cache_w = ( EnzymeCore. EnzymeRules. overwritten (config)[4 ] && ! (typeof (x) <: Const ) ) ? copy (w. val) : nothing
31
+
32
+ cache = (cache_x, cache_w)
33
+
34
+ return EnzymeCore. EnzymeRules. AugmentedReturn (primal, shadow, cache)
35
+ end
36
+
37
+ function EnzymeCore. EnzymeRules. reverse (config, func:: Const{typeof(NNlib.conv!)} , :: Type{RT} , cache, y, x, w, cdims; kwargs... ) where {RT}
38
+ cache_x, cache_w = cache
39
+
40
+ # Don't cache x if not overwritten and w is active (and thus required)
41
+ if ! (typeof (w) <: Const )
42
+ if ! EnzymeCore. EnzymeRules. overwritten (config)[3 ]
43
+ cache_x = x. val
44
+ end
45
+ end
46
+
47
+ # Don't cache w if not overwritten and x is active (and thus required)
48
+ if ! (typeof (x) <: Const )
49
+ if ! EnzymeCore. EnzymeRules. overwritten (config)[4 ]
50
+ cache_w = w. val
51
+ end
52
+ end
53
+
54
+ dys = y. dval
55
+ dxs = (typeof (x) <: Const ) ? dys : x. dval
56
+ dws = (typeof (w) <: Const ) ? dys : w. dval
57
+
58
+ if EnzymeCore. EnzymeRules. width (config) == 1
59
+ dys = (dys,)
60
+ dxs = (dxs,)
61
+ dws = (dws,)
62
+ end
63
+
64
+ for (dy, dx, dw) in zip (dys, dxs, dws)
65
+ if ! (typeof (x) <: Const ) && dx != = x
66
+ # dx += grad wrt x
67
+ NNlib.∇conv_data! (dx, dy, cache_w, cdims. val; alpha= eltype (dw)(1 ), beta= eltype (dw)(1 ), kwargs... )
68
+ end
69
+ if ! (typeof (w) <: Const ) && dw != = w
70
+ # dw += grad wrt w
71
+ NNlib.∇conv_filter! (dw, cache_x, dy, cdims. val; alpha= eltype (dw)(1 ), beta= eltype (dw)(1 ), kwargs... )
72
+ end
73
+ dy .= 0
74
+ end
75
+
76
+ return (nothing , nothing , nothing , nothing )
77
+ end
78
+
79
+ end
0 commit comments