1
1
# # Pushforward
2
2
3
- DI. prepare_pushforward (f, :: AutoForwardOrNothingEnzyme , x, dx) = NoPushforwardExtras ()
3
+ function DI. prepare_pushforward (f, :: AnyAutoEnzyme{<:Union{ForwardMode,Nothing}} , x, dx)
4
+ return NoPushforwardExtras ()
5
+ end
4
6
5
7
function DI. value_and_pushforward (
6
- f, backend:: AutoForwardOrNothingEnzyme , x, dx, :: NoPushforwardExtras
8
+ f, backend:: AnyAutoEnzyme{<:Union{ForwardMode,Nothing}} , x, dx, :: NoPushforwardExtras
7
9
)
8
10
dx_sametype = convert (typeof (x), dx)
9
- y, new_dy = autodiff (
10
- forward_mode (backend), Const (f), Duplicated, Duplicated (x, dx_sametype)
11
- )
11
+ x_and_dx = Duplicated (x, dx_sametype)
12
+ y, new_dy = if backend isa AutoDeferredEnzyme
13
+ autodiff_deferred (forward_mode (backend), f, Duplicated, x_and_dx)
14
+ else
15
+ autodiff (forward_mode (backend), Const (f), Duplicated, x_and_dx)
16
+ end
12
17
return y, new_dy
13
18
end
14
19
15
20
function DI. pushforward (
16
- f, backend:: AutoForwardOrNothingEnzyme , x, dx, :: NoPushforwardExtras
21
+ f, backend:: AnyAutoEnzyme{<:Union{ForwardMode,Nothing}} , x, dx, :: NoPushforwardExtras
17
22
)
18
23
dx_sametype = convert (typeof (x), dx)
19
- new_dy = only (
20
- autodiff (
21
- forward_mode (backend), Const (f), DuplicatedNoNeed, Duplicated (x, dx_sametype)
22
- ),
23
- )
24
+ x_and_dx = Duplicated (x, dx_sametype)
25
+ new_dy = if backend isa AutoDeferredEnzyme
26
+ only (autodiff_deferred (forward_mode (backend), f, DuplicatedNoNeed, x_and_dx))
27
+ else
28
+ only (autodiff (forward_mode (backend), Const (f), DuplicatedNoNeed, x_and_dx))
29
+ end
24
30
return new_dy
25
31
end
26
32
27
33
function DI. value_and_pushforward! (
28
- f, dy, backend:: AutoForwardOrNothingEnzyme , x, dx, extras:: NoPushforwardExtras
34
+ f,
35
+ dy,
36
+ backend:: AnyAutoEnzyme{<:Union{ForwardMode,Nothing}} ,
37
+ x,
38
+ dx,
39
+ extras:: NoPushforwardExtras ,
29
40
)
30
41
# dy cannot be passed anyway
31
42
y, new_dy = DI. value_and_pushforward (f, backend, x, dx, extras)
32
43
return y, copyto! (dy, new_dy)
33
44
end
34
45
35
46
function DI. pushforward! (
36
- f, dy, backend:: AutoForwardOrNothingEnzyme , x, dx, extras:: NoPushforwardExtras
47
+ f,
48
+ dy,
49
+ backend:: AnyAutoEnzyme{<:Union{ForwardMode,Nothing}} ,
50
+ x,
51
+ dx,
52
+ extras:: NoPushforwardExtras ,
37
53
)
38
54
# dy cannot be passed anyway
39
55
return copyto! (dy, DI. pushforward (f, backend, x, dx, extras))
@@ -45,34 +61,34 @@ struct EnzymeForwardGradientExtras{C,O} <: GradientExtras
45
61
shadow:: O
46
62
end
47
63
48
- function DI. prepare_gradient (f, :: AutoForwardEnzyme , x)
64
+ function DI. prepare_gradient (f, :: AutoEnzyme{<:ForwardMode} , x)
49
65
C = pick_chunksize (length (x))
50
66
shadow = chunkedonehot (x, Val (C))
51
67
return EnzymeForwardGradientExtras {C,typeof(shadow)} (shadow)
52
68
end
53
69
54
70
function DI. gradient (
55
- f, backend:: AutoForwardEnzyme , x, extras:: EnzymeForwardGradientExtras{C}
71
+ f, backend:: AutoEnzyme{<:ForwardMode} , x, extras:: EnzymeForwardGradientExtras{C}
56
72
) where {C}
57
73
grad_tup = gradient (forward_mode (backend), f, x, Val {C} (); shadow= extras. shadow)
58
74
return reshape (collect (grad_tup), size (x))
59
75
end
60
76
61
77
function DI. value_and_gradient (
62
- f, backend:: AutoForwardEnzyme , x, extras:: EnzymeForwardGradientExtras
78
+ f, backend:: AutoEnzyme{<:ForwardMode} , x, extras:: EnzymeForwardGradientExtras
63
79
)
64
80
return f (x), DI. gradient (f, backend, x, extras)
65
81
end
66
82
67
83
function DI. gradient! (
68
- f, grad, backend:: AutoForwardEnzyme , x, extras:: EnzymeForwardGradientExtras{C}
84
+ f, grad, backend:: AutoEnzyme{<:ForwardMode} , x, extras:: EnzymeForwardGradientExtras{C}
69
85
) where {C}
70
86
grad_tup = gradient (forward_mode (backend), f, x, Val {C} (); shadow= extras. shadow)
71
87
return copyto! (grad, grad_tup)
72
88
end
73
89
74
90
function DI. value_and_gradient! (
75
- f, grad, backend:: AutoForwardEnzyme , x, extras:: EnzymeForwardGradientExtras{C}
91
+ f, grad, backend:: AutoEnzyme{<:ForwardMode} , x, extras:: EnzymeForwardGradientExtras{C}
76
92
) where {C}
77
93
grad_tup = gradient (forward_mode (backend), f, x, Val {C} (); shadow= extras. shadow)
78
94
return f (x), copyto! (grad, grad_tup)
@@ -84,14 +100,17 @@ struct EnzymeForwardOneArgJacobianExtras{C,O} <: JacobianExtras
84
100
shadow:: O
85
101
end
86
102
87
- function DI. prepare_jacobian (f, :: AutoForwardOrNothingEnzyme , x)
103
+ function DI. prepare_jacobian (f, :: AutoEnzyme{<:Union{ForwardMode,Nothing}} , x)
88
104
C = pick_chunksize (length (x))
89
105
shadow = chunkedonehot (x, Val (C))
90
106
return EnzymeForwardOneArgJacobianExtras {C,typeof(shadow)} (shadow)
91
107
end
92
108
93
109
function DI. jacobian (
94
- f, backend:: AutoForwardOrNothingEnzyme , x, extras:: EnzymeForwardOneArgJacobianExtras{C}
110
+ f,
111
+ backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
112
+ x,
113
+ extras:: EnzymeForwardOneArgJacobianExtras{C} ,
95
114
) where {C}
96
115
jac_wrongshape = jacobian (forward_mode (backend), f, x, Val {C} (); shadow= extras. shadow)
97
116
nx = length (x)
@@ -100,15 +119,18 @@ function DI.jacobian(
100
119
end
101
120
102
121
function DI. value_and_jacobian (
103
- f, backend:: AutoForwardOrNothingEnzyme , x, extras:: EnzymeForwardOneArgJacobianExtras
122
+ f,
123
+ backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
124
+ x,
125
+ extras:: EnzymeForwardOneArgJacobianExtras ,
104
126
)
105
127
return f (x), DI. jacobian (f, backend, x, extras)
106
128
end
107
129
108
130
function DI. jacobian! (
109
131
f,
110
132
jac,
111
- backend:: AutoForwardOrNothingEnzyme ,
133
+ backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
112
134
x,
113
135
extras:: EnzymeForwardOneArgJacobianExtras ,
114
136
)
118
140
function DI. value_and_jacobian! (
119
141
f,
120
142
jac,
121
- backend:: AutoForwardOrNothingEnzyme ,
143
+ backend:: AnyAutoEnzyme{<:Union{ForwardMode,Nothing}} ,
122
144
x,
123
145
extras:: EnzymeForwardOneArgJacobianExtras ,
124
146
)
0 commit comments