1
+ module NNlibEnzymeExt
2
+
3
+ using NNlib
4
+ isdefined (Base, :get_extension ) ? (import Enzyme) : (import .. Enzyme)
5
+
6
+ using Enzyme
7
+
8
+ using EnzymeCore
9
+
10
+ function EnzymeCore. EnzymeRules. augmented_primal (config, func:: Const{typeof(NNlib.conv!)} , :: Type{RT} , y:: OutType , x, w, cdims:: Const ; kwargs... )
11
+ ) where {OutType, RT}
12
+
13
+ @assert ! (OutType <: Const )
14
+ if OutType <: Duplicated || OutType <: DuplicatedNoNeed
15
+ func. val (y. val, x. val, y. val, cdims. val; kwargs... )
16
+ end
17
+
18
+ dres = if EnzymeRules. width (config) == 1
19
+ func. val (prob. dval, alg. val; kwargs... )
20
+ else
21
+ ntuple (Val (EnzymeRules. width (config))) do i
22
+ Base. @_inline_meta
23
+ func. val (prob. dval[i], alg. val; kwargs... )
24
+ end
25
+ end
26
+
27
+ primal = if EnzymeRules. needs_primal (config)
28
+ y. val
29
+ else
30
+ nothing
31
+ end
32
+ shadow = if EnzymeRules. needs_shadow (config)
33
+ y. dval
34
+ else
35
+ nothing
36
+ end
37
+
38
+ # Cache x if its overwritten and w is active (and thus required)
39
+ cache_x = ( EnzymeRules. overwritten (config)[3 ] && ! (typeof (w) <: Const ) ) ? copy (x. val) : nothing
40
+
41
+ # Cache w if its overwritten and x is active (and thus required)
42
+ cache_w = ( EnzymeRules. overwritten (config)[4 ] && ! (typeof (x) <: Const ) ) ? copy (w. val) : nothing
43
+
44
+ cache = (cache_x, cache_w)
45
+
46
+ return EnzymeCore. EnzymeRules. AugmentedReturn (y. val, y. dval, cache)
47
+ end
48
+
49
+ function EnzymeCore. EnzymeRules. reverse (config, func:: Const{typeof(NNlib.conv!)} , :: Type{RT} , cache, y, x, w, cdims:: Const ; kwargs... ) where {RT}
50
+ cache_x, cache_w = cache
51
+
52
+ # Don't cache x if not overwritten and w is active (and thus required)
53
+ if ! (typeof (w) <: Const )
54
+ if ! EnzymeRules. overwritten (config)[3 ]
55
+ cache_x = x. val
56
+ end
57
+ end
58
+
59
+ # Don't cache w if not overwritten and x is active (and thus required)
60
+ if ! (typeof (x) <: Const )
61
+ if ! EnzymeRules. overwritten (config)[4 ]
62
+ cache_w = w. val
63
+ end
64
+ end
65
+
66
+ dys = y. dval
67
+ dxs = (typeof (x) <: Const ) ? nothing : x. dval
68
+ dws = (typeof (w) <: Const ) ? nothing : w. dval
69
+
70
+ if EnzymeRules. width (config) == 1
71
+ dys = (dys,)
72
+ dxs = (dxs,)
73
+ dws = (dws,)
74
+ end
75
+
76
+ for (dy, dx, dw) in (dys, dxs, dws)
77
+ if ! (typeof (x) <: Const )
78
+ # dx += grad wrt x
79
+ NNlib.∇conv_data! (dx, dy, cache_w, cdims; alpha= 1 , beta= 1 , kwargs... )
80
+ end
81
+ if ! (typeof (y) <: Const )
82
+ # dw += grad wrt w
83
+ NNlib.∇conv_filter! (dw, cache_x, dy, cdims; alpha= 1 , beta= 1 , kwargs... )
84
+ end
85
+ end
86
+
87
+ return (nothing , nothing , nothing , nothing )
88
+ end
0 commit comments