Skip to content

Commit 39dda67

Browse files
authored
Support second order for Enzyme, take 2 (#285)
* Add nested mechanism * Mode * Typo * Tests passing * Remove additional tests FD+Enzyme * Add printing * Remove printing * More code coverage * Fix Heisenbug * Fix Heisenbug * Debump
1 parent c6aaabe commit 39dda67

File tree

23 files changed

+332
-181
lines changed

23 files changed

+332
-181
lines changed

DifferentiationInterface/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.5.2"
4+
version = "0.5.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -88,6 +88,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
8888
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
8989
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
9090
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
91+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
9192
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
9293
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
9394
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -106,5 +107,6 @@ test = [
106107
"SparseArrays",
107108
"SparseConnectivityTracer",
108109
"SparseMatrixColorings",
110+
"StableRNGs",
109111
"Test",
110112
]

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,28 +22,40 @@ using Enzyme:
2222
DuplicatedNoNeed,
2323
Forward,
2424
ForwardMode,
25+
Mode,
2526
Reverse,
2627
ReverseWithPrimal,
2728
ReverseSplitWithPrimal,
2829
ReverseMode,
2930
autodiff,
31+
autodiff_deferred,
32+
autodiff_deferred_thunk,
3033
autodiff_thunk,
3134
chunkedonehot,
3235
gradient,
3336
gradient!,
3437
jacobian,
3538
make_zero
3639

37-
const AutoForwardEnzyme = AutoEnzyme{<:ForwardMode}
38-
const AutoForwardOrNothingEnzyme = Union{AutoEnzyme{<:ForwardMode},AutoEnzyme{Nothing}}
39-
const AutoReverseEnzyme = AutoEnzyme{<:ReverseMode}
40-
const AutoReverseOrNothingEnzyme = Union{AutoEnzyme{<:ReverseMode},AutoEnzyme{Nothing}}
40+
struct AutoDeferredEnzyme{M} <: ADTypes.AbstractADType
41+
mode::M
42+
end
43+
44+
ADTypes.mode(backend::AutoDeferredEnzyme) = ADTypes.mode(AutoEnzyme(backend.mode))
45+
46+
DI.backend_package_name(::AutoDeferredEnzyme) = "DeferredEnzyme"
47+
48+
DI.nested(backend::AutoEnzyme) = AutoDeferredEnzyme(backend.mode)
4149

42-
forward_mode(backend::AutoEnzyme{<:ForwardMode}) = backend.mode
43-
forward_mode(::AutoEnzyme{Nothing}) = Forward
50+
const AnyAutoEnzyme{M} = Union{AutoEnzyme{M},AutoDeferredEnzyme{M}}
4451

45-
reverse_mode(backend::AutoEnzyme{<:ReverseMode}) = backend.mode
46-
reverse_mode(::AutoEnzyme{Nothing}) = Reverse
52+
# forward mode if possible
53+
forward_mode(backend::AnyAutoEnzyme{<:Mode}) = backend.mode
54+
forward_mode(::AnyAutoEnzyme{Nothing}) = Forward
55+
56+
# reverse mode if possible
57+
reverse_mode(backend::AnyAutoEnzyme{<:Mode}) = backend.mode
58+
reverse_mode(::AnyAutoEnzyme{Nothing}) = Reverse
4759

4860
DI.check_available(::AutoEnzyme) = true
4961

@@ -54,12 +66,6 @@ function DI.basis(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T
5466
return b
5567
end
5668

57-
function zero_sametype!(x_target, x)
58-
x_sametype = convert(typeof(x), x_target)
59-
x_sametype .= zero(eltype(x_sametype))
60-
return x_sametype
61-
end
62-
6369
include("forward_onearg.jl")
6470
include("forward_twoarg.jl")
6571

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,55 @@
11
## Pushforward
22

3-
DI.prepare_pushforward(f, ::AutoForwardOrNothingEnzyme, x, dx) = NoPushforwardExtras()
3+
function DI.prepare_pushforward(f, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx)
4+
return NoPushforwardExtras()
5+
end
46

57
function DI.value_and_pushforward(
6-
f, backend::AutoForwardOrNothingEnzyme, x, dx, ::NoPushforwardExtras
8+
f, backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx, ::NoPushforwardExtras
79
)
810
dx_sametype = convert(typeof(x), dx)
9-
y, new_dy = autodiff(
10-
forward_mode(backend), Const(f), Duplicated, Duplicated(x, dx_sametype)
11-
)
11+
x_and_dx = Duplicated(x, dx_sametype)
12+
y, new_dy = if backend isa AutoDeferredEnzyme
13+
autodiff_deferred(forward_mode(backend), f, Duplicated, x_and_dx)
14+
else
15+
autodiff(forward_mode(backend), Const(f), Duplicated, x_and_dx)
16+
end
1217
return y, new_dy
1318
end
1419

1520
function DI.pushforward(
16-
f, backend::AutoForwardOrNothingEnzyme, x, dx, ::NoPushforwardExtras
21+
f, backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx, ::NoPushforwardExtras
1722
)
1823
dx_sametype = convert(typeof(x), dx)
19-
new_dy = only(
20-
autodiff(
21-
forward_mode(backend), Const(f), DuplicatedNoNeed, Duplicated(x, dx_sametype)
22-
),
23-
)
24+
x_and_dx = Duplicated(x, dx_sametype)
25+
new_dy = if backend isa AutoDeferredEnzyme
26+
only(autodiff_deferred(forward_mode(backend), f, DuplicatedNoNeed, x_and_dx))
27+
else
28+
only(autodiff(forward_mode(backend), Const(f), DuplicatedNoNeed, x_and_dx))
29+
end
2430
return new_dy
2531
end
2632

2733
function DI.value_and_pushforward!(
28-
f, dy, backend::AutoForwardOrNothingEnzyme, x, dx, extras::NoPushforwardExtras
34+
f,
35+
dy,
36+
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
37+
x,
38+
dx,
39+
extras::NoPushforwardExtras,
2940
)
3041
# dy cannot be passed anyway
3142
y, new_dy = DI.value_and_pushforward(f, backend, x, dx, extras)
3243
return y, copyto!(dy, new_dy)
3344
end
3445

3546
function DI.pushforward!(
36-
f, dy, backend::AutoForwardOrNothingEnzyme, x, dx, extras::NoPushforwardExtras
47+
f,
48+
dy,
49+
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
50+
x,
51+
dx,
52+
extras::NoPushforwardExtras,
3753
)
3854
# dy cannot be passed anyway
3955
return copyto!(dy, DI.pushforward(f, backend, x, dx, extras))
@@ -45,34 +61,34 @@ struct EnzymeForwardGradientExtras{C,O} <: GradientExtras
4561
shadow::O
4662
end
4763

48-
function DI.prepare_gradient(f, ::AutoForwardEnzyme, x)
64+
function DI.prepare_gradient(f, ::AutoEnzyme{<:ForwardMode}, x)
4965
C = pick_chunksize(length(x))
5066
shadow = chunkedonehot(x, Val(C))
5167
return EnzymeForwardGradientExtras{C,typeof(shadow)}(shadow)
5268
end
5369

5470
function DI.gradient(
55-
f, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C}
71+
f, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C}
5672
) where {C}
5773
grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
5874
return reshape(collect(grad_tup), size(x))
5975
end
6076

6177
function DI.value_and_gradient(
62-
f, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras
78+
f, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras
6379
)
6480
return f(x), DI.gradient(f, backend, x, extras)
6581
end
6682

6783
function DI.gradient!(
68-
f, grad, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C}
84+
f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C}
6985
) where {C}
7086
grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
7187
return copyto!(grad, grad_tup)
7288
end
7389

7490
function DI.value_and_gradient!(
75-
f, grad, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C}
91+
f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C}
7692
) where {C}
7793
grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
7894
return f(x), copyto!(grad, grad_tup)
@@ -84,14 +100,17 @@ struct EnzymeForwardOneArgJacobianExtras{C,O} <: JacobianExtras
84100
shadow::O
85101
end
86102

87-
function DI.prepare_jacobian(f, ::AutoForwardOrNothingEnzyme, x)
103+
function DI.prepare_jacobian(f, ::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x)
88104
C = pick_chunksize(length(x))
89105
shadow = chunkedonehot(x, Val(C))
90106
return EnzymeForwardOneArgJacobianExtras{C,typeof(shadow)}(shadow)
91107
end
92108

93109
function DI.jacobian(
94-
f, backend::AutoForwardOrNothingEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras{C}
110+
f,
111+
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
112+
x,
113+
extras::EnzymeForwardOneArgJacobianExtras{C},
95114
) where {C}
96115
jac_wrongshape = jacobian(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
97116
nx = length(x)
@@ -100,15 +119,18 @@ function DI.jacobian(
100119
end
101120

102121
function DI.value_and_jacobian(
103-
f, backend::AutoForwardOrNothingEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras
122+
f,
123+
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
124+
x,
125+
extras::EnzymeForwardOneArgJacobianExtras,
104126
)
105127
return f(x), DI.jacobian(f, backend, x, extras)
106128
end
107129

108130
function DI.jacobian!(
109131
f,
110132
jac,
111-
backend::AutoForwardOrNothingEnzyme,
133+
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
112134
x,
113135
extras::EnzymeForwardOneArgJacobianExtras,
114136
)
@@ -118,7 +140,7 @@ end
118140
function DI.value_and_jacobian!(
119141
f,
120142
jac,
121-
backend::AutoForwardOrNothingEnzyme,
143+
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
122144
x,
123145
extras::EnzymeForwardOneArgJacobianExtras,
124146
)
Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
11
## Pushforward
22

3-
DI.prepare_pushforward(f!, y, ::AutoForwardOrNothingEnzyme, x, dx) = NoPushforwardExtras()
3+
function DI.prepare_pushforward(f!, y, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx)
4+
return NoPushforwardExtras()
5+
end
46

57
function DI.value_and_pushforward(
6-
f!, y, backend::AutoForwardOrNothingEnzyme, x, dx, ::NoPushforwardExtras
8+
f!,
9+
y,
10+
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
11+
x,
12+
dx,
13+
::NoPushforwardExtras,
714
)
815
dx_sametype = convert(typeof(x), dx)
9-
dy_sametype = zero(y)
10-
autodiff(
11-
forward_mode(backend),
12-
Const(f!),
13-
Const,
14-
Duplicated(y, dy_sametype),
15-
Duplicated(x, dx_sametype),
16-
)
16+
dy_sametype = make_zero(y)
17+
y_and_dy = Duplicated(y, dy_sametype)
18+
x_and_dx = Duplicated(x, dx_sametype)
19+
if backend isa AutoDeferredEnzyme
20+
autodiff_deferred(forward_mode(backend), f!, Const, y_and_dy, x_and_dx)
21+
else
22+
autodiff(forward_mode(backend), Const(f!), Const, y_and_dy, x_and_dx)
23+
end
1724
return y, dy_sametype
1825
end

0 commit comments

Comments
 (0)