Commit b985eb5

feat: support forward-mode Mooncake with AutoMooncakeForward (#813)
* feat: support forward-mode Mooncake [experimental]
* Fix comma
* Bump versions
* Test rule
* Format
* Fix
* Format
* Fix coverage
* Docs
* Fix check prep
* Fix
* Add config
1 parent cfb1a94 commit b985eb5

9 files changed, +240 -15 lines changed

DifferentiationInterface/Project.toml

Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
 name = "DifferentiationInterface"
 uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 authors = ["Guillaume Dalle", "Adrian Hill"]
-version = "0.7.4"
+version = "0.7.5"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -56,7 +56,7 @@ DifferentiationInterfaceTrackerExt = "Tracker"
 DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]
 
 [compat]
-ADTypes = "1.13.0"
+ADTypes = "1.17.0"
 Aqua = "0.8.12"
 ChainRulesCore = "1.23.0"
 ComponentArrays = "0.15.27"
@@ -77,7 +77,7 @@ JET = "0.9"
 JLArrays = "0.2.0"
 JuliaFormatter = "1,2"
 LinearAlgebra = "1"
-Mooncake = "0.4.122"
+Mooncake = "0.4.147"
 Pkg = "1"
 PolyesterForwardDiff = "0.1.2"
 Random = "1"

DifferentiationInterface/docs/src/explanation/backends.md

Lines changed: 4 additions & 2 deletions
@@ -12,7 +12,7 @@ We support the following dense backend choices from [ADTypes.jl](https://github.
 - [`AutoFiniteDifferences`](@extref ADTypes.AutoFiniteDifferences)
 - [`AutoForwardDiff`](@extref ADTypes.AutoForwardDiff)
 - [`AutoGTPSA`](@extref ADTypes.AutoGTPSA)
-- [`AutoMooncake`](@extref ADTypes.AutoMooncake)
+- [`AutoMooncake`](@extref ADTypes.AutoMooncake) and [`AutoMooncakeForward`](@extref ADTypes.AutoMooncake) (the latter is experimental)
 - [`AutoPolyesterForwardDiff`](@extref ADTypes.AutoPolyesterForwardDiff)
 - [`AutoReverseDiff`](@extref ADTypes.AutoReverseDiff)
 - [`AutoSymbolics`](@extref ADTypes.AutoSymbolics)
@@ -48,6 +48,7 @@ In practice, many AD backends have custom implementations for high-level operato
 | `AutoForwardDiff` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
 | `AutoGTPSA` | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
 | `AutoMooncake` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| `AutoMooncakeForward` | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | `AutoPolyesterForwardDiff` | 🔀 | ❌ | 🔀 | ✅ | ✅ | 🔀 | 🔀 | 🔀 |
 | `AutoReverseDiff` | ❌ | 🔀 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
 | `AutoSymbolics` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
@@ -68,6 +69,7 @@ Moreover, each context type is supported by a specific subset of backends:
 | `AutoForwardDiff` |||
 | `AutoGTPSA` |||
 | `AutoMooncake` |||
+| `AutoMooncakeForward` |||
 | `AutoPolyesterForwardDiff` |||
 | `AutoReverseDiff` |||
 | `AutoSymbolics` |||
@@ -95,7 +97,7 @@ In general, using a forward outer backend over a reverse inner backend will yiel
 The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends.
 It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use.
 In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself.
-At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)).
+At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), reverse-mode [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl), or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend (e.g., [Zygote.jl](https://github.com/FluxML/Zygote.jl)).
 
 ## Implementations
 
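
To make the documented backend concrete, here is a minimal usage sketch of a pushforward with `AutoMooncakeForward`. It assumes DifferentiationInterface 0.7.5 or later and Mooncake 0.4.147 or later are installed; the function `f` and the inputs are invented for illustration, and because the backend is marked experimental its behavior may change.

```julia
using DifferentiationInterface
import Mooncake  # loading Mooncake activates the corresponding extension

# Illustrative function (not part of the commit): squared Euclidean norm.
f(x) = sum(abs2, x)

backend = AutoMooncakeForward(; config=nothing)
x = [1.0, 2.0, 3.0]
tx = ([1.0, 0.0, 0.0],)  # tuple of input tangents (here a single direction)

# Preparation builds the backend cache once, so later calls can reuse it.
prep = prepare_pushforward(f, backend, x, tx)
y, ty = value_and_pushforward(f, prep, backend, x, tx)
```

Analytically, the directional derivative of `sum(abs2, x)` along the first basis vector is `2 * x[1]`, so `ty[1]` should match that value if the backend works as intended.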

Lines changed: 13 additions & 7 deletions
@@ -1,16 +1,22 @@
 module DifferentiationInterfaceMooncakeExt
 
-using ADTypes: ADTypes, AutoMooncake
+using ADTypes: ADTypes, AutoMooncake, AutoMooncakeForward
 import DifferentiationInterface as DI
 using Mooncake:
     Mooncake,
     CoDual,
     Config,
+    Dual,
+    prepare_derivative_cache,
     prepare_gradient_cache,
     prepare_pullback_cache,
+    primal,
+    tangent,
     tangent_type,
+    value_and_derivative!!,
     value_and_gradient!!,
     value_and_pullback!!,
+    zero_dual,
     zero_tangent,
     rdata_type,
     fdata,
@@ -25,17 +31,17 @@ using Mooncake:
     _copy_output,
     _copy_to_output!!
 
-DI.check_available(::AutoMooncake) = true
+const AnyAutoMooncake{C} = Union{AutoMooncake{C},AutoMooncakeForward{C}}
 
-get_config(::AutoMooncake{Nothing}) = Config()
-get_config(backend::AutoMooncake{<:Config}) = backend.config
+DI.check_available(::AnyAutoMooncake{C}) where {C} = true
 
-# tangents need to be copied before returning, otherwise they are still aliased in the cache
-mycopy(x::Union{Number,AbstractArray{<:Number}}) = copy(x)
-mycopy(x) = deepcopy(x)
+get_config(::AnyAutoMooncake{Nothing}) = Config()
+get_config(backend::AnyAutoMooncake{<:Config}) = backend.config
 
 include("onearg.jl")
 include("twoarg.jl")
+include("forward_onearg.jl")
+include("forward_twoarg.jl")
 include("differentiate_with.jl")
 
 end
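
The `AnyAutoMooncake{C}` alias above lets one method serve both the reverse-mode and forward-mode backend types. The following self-contained sketch shows the same dispatch pattern in plain Julia; `MyConfig`, `ReverseBackend`, `ForwardBackend`, and `AnyBackend` are hypothetical stand-ins, not the real ADTypes or Mooncake types.

```julia
# Hypothetical config type standing in for a backend configuration object.
struct MyConfig
    debug_mode::Bool
end
MyConfig() = MyConfig(false)

# Hypothetical backend types, each parametrized by its config type.
struct ReverseBackend{C}
    config::C
end
struct ForwardBackend{C}
    config::C
end

# Parametric alias: one name that covers both backends for a given C.
const AnyBackend{C} = Union{ReverseBackend{C},ForwardBackend{C}}

# With `nothing` as the config, fall back to a default configuration.
get_config(::AnyBackend{Nothing}) = MyConfig()
# Otherwise return the configuration stored in the backend.
get_config(backend::AnyBackend{<:MyConfig}) = backend.config

default_cfg = get_config(ForwardBackend(nothing))
custom_cfg = get_config(ReverseBackend(MyConfig(true)))
```

Writing the alias once avoids duplicating every helper method for each backend type, at the cost of requiring both types to expose the same `config` field.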
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+## Pushforward
+
+struct MooncakeOneArgPushforwardPrep{SIG,Tcache,DX} <: DI.PushforwardPrep{SIG}
+    _sig::Val{SIG}
+    cache::Tcache
+    dx_righttype::DX
+end
+
+function DI.prepare_pushforward_nokwarg(
+    strict::Val,
+    f::F,
+    backend::AutoMooncakeForward,
+    x,
+    tx::NTuple,
+    contexts::Vararg{DI.Context,C};
+) where {F,C}
+    _sig = DI.signature(f, backend, x, tx, contexts...; strict)
+    config = get_config(backend)
+    cache = prepare_derivative_cache(
+        f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
+    )
+    dx_righttype = zero_tangent(x)
+    prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype)
+    return prep
+end
+
+function DI.value_and_pushforward(
+    f::F,
+    prep::MooncakeOneArgPushforwardPrep,
+    backend::AutoMooncakeForward,
+    x::X,
+    tx::NTuple,
+    contexts::Vararg{DI.Context,C};
+) where {F,C,X}
+    DI.check_prep(f, prep, backend, x, tx, contexts...)
+    ys_and_ty = map(tx) do dx
+        dx_righttype =
+            dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
+        y_dual = value_and_derivative!!(
+            prep.cache,
+            zero_dual(f),
+            Dual(x, dx_righttype),
+            map(zero_dual ∘ DI.unwrap, contexts)...,
+        )
+        y = primal(y_dual)
+        dy = _copy_output(tangent(y_dual))
+        return y, dy
+    end
+    y = first(ys_and_ty[1])
+    ty = last.(ys_and_ty)
+    return y, ty
+end
+
+function DI.pushforward(
+    f::F,
+    prep::MooncakeOneArgPushforwardPrep,
+    backend::AutoMooncakeForward,
+    x,
+    tx::NTuple,
+    contexts::Vararg{DI.Context,C};
+) where {F,C}
+    DI.check_prep(f, prep, backend, x, tx, contexts...)
+    return DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)[2]
+end
+
+function DI.value_and_pushforward!(
+    f::F,
+    ty::NTuple,
+    prep::MooncakeOneArgPushforwardPrep,
+    backend::AutoMooncakeForward,
+    x,
+    tx::NTuple,
+    contexts::Vararg{DI.Context,C};
+) where {F,C}
+    DI.check_prep(f, prep, backend, x, tx, contexts...)
+    y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)
+    foreach(copyto!, ty, new_ty)
+    return y, ty
+end
+
+function DI.pushforward!(
+    f::F,
+    ty::NTuple,
+    prep::MooncakeOneArgPushforwardPrep,
+    backend::AutoMooncakeForward,
+    x,
+    tx::NTuple,
+    contexts::Vararg{DI.Context,C};
+) where {F,C}
+    DI.check_prep(f, prep, backend, x, tx, contexts...)
+    DI.value_and_pushforward!(f, ty, prep, backend, x, tx, contexts...)
+    return ty
+end
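
The out-of-place `value_and_pushforward` above runs one forward pass per input tangent, then takes the primal from the first pass and collects the output tangents into a tuple. The sketch below reproduces that control flow with a toy scalar dual number; `MiniDual` and `mini_value_and_pushforward` are hypothetical names and have nothing to do with how Mooncake represents duals internally.

```julia
# Toy dual number: a primal value paired with a tangent (hypothetical type).
struct MiniDual
    primal::Float64
    tangent::Float64
end

# Forward-mode rules for the few operations used in this example.
Base.:+(a::MiniDual, b::MiniDual) = MiniDual(a.primal + b.primal, a.tangent + b.tangent)
Base.:*(a::MiniDual, b::MiniDual) =
    MiniDual(a.primal * b.primal, a.tangent * b.primal + a.primal * b.tangent)
Base.sin(a::MiniDual) = MiniDual(sin(a.primal), cos(a.primal) * a.tangent)

# One forward pass per tangent; the primal is the same in every pass.
function mini_value_and_pushforward(f, x::Float64, tx::NTuple{N,Float64}) where {N}
    ys_and_ty = map(tx) do dx
        y_dual = f(MiniDual(x, dx))
        return y_dual.primal, y_dual.tangent
    end
    y = first(ys_and_ty[1])  # primal output from the first pass
    ty = last.(ys_and_ty)    # tuple with one output tangent per input tangent
    return y, ty
end

g(x) = x * x + sin(x)
y, ty = mini_value_and_pushforward(g, 2.0, (1.0, 0.5))
```

Since `g'(x) = 2x + cos(x)`, each entry of `ty` should equal `(4 + cos(2))` times the corresponding input tangent.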
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+## Pushforward
+
+struct MooncakeTwoArgPushforwardPrep{SIG,Tcache,DX,DY} <: DI.PushforwardPrep{SIG}
+    _sig::Val{SIG}
+    cache::Tcache
+    dx_righttype::DX
+    dy_righttype::DY
+end
+
+function DI.prepare_pushforward_nokwarg(
+    strict::Val,
+    f!::F,
+    y,
+    backend::AutoMooncakeForward,
+    x,
+    tx::NTuple,
+    contexts::Vararg{DI.Context,C};
+) where {F,C}
+    _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict)
+    config = get_config(backend)
+    cache = prepare_derivative_cache(
+        f!,
+        y,
+        x,
+        map(DI.unwrap, contexts)...;
+        config.debug_mode,
+        config.silence_debug_messages,
+    )
+    dx_righttype = zero_tangent(x)
+    dy_righttype = zero_tangent(y)
+    prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dx_righttype, dy_righttype)
+    return prep
+end
+
+function DI.value_and_pushforward(
+    f!::F,
+    y,
+    prep::MooncakeTwoArgPushforwardPrep,
+    backend::AutoMooncakeForward,
+    x::X,
+    tx::NTuple,
+    contexts::Vararg{DI.Context,C};
+) where {F,C,X}
+    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
+    ty = map(tx) do dx
+        dx_righttype =
+            dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
+        y_dual = zero_dual(y)
+        value_and_derivative!!(
+            prep.cache,
+            zero_dual(f!),
+            y_dual,
+            Dual(x, dx_righttype),
+            map(zero_dual ∘ DI.unwrap, contexts)...,
+        )
+        dy = _copy_output(tangent(y_dual))
+        return dy
+    end
+    return y, ty
+end
+
+function DI.pushforward(
+    f!::F,
+    y,
+    prep::MooncakeTwoArgPushforwardPrep,
+    backend::AutoMooncakeForward,
+    x,
+    tx::NTuple,
+    contexts::Vararg{DI.Context,C};
+) where {F,C}
+    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
+    return DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)[2]
+end
+
+function DI.value_and_pushforward!(
+    f!::F,
+    y::Y,
+    ty::NTuple,
+    prep::MooncakeTwoArgPushforwardPrep,
+    backend::AutoMooncakeForward,
+    x::X,
+    tx::NTuple,
+    contexts::Vararg{DI.Context,C};
+) where {F,C,X,Y}
+    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
+    foreach(tx, ty) do dx, dy
+        dx_righttype =
+            dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
+        dy_righttype =
+            dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
+        value_and_derivative!!(
+            prep.cache,
+            zero_dual(f!),
+            Dual(y, dy_righttype),
+            Dual(x, dx_righttype),
+            map(zero_dual ∘ DI.unwrap, contexts)...,
+        )
+        dy === dy_righttype || copyto!(dy, dy_righttype)
+    end
+    return y, ty
+end
+
+function DI.pushforward!(
+    f!::F,
+    y,
+    ty::NTuple,
+    prep::MooncakeTwoArgPushforwardPrep,
+    backend::AutoMooncakeForward,
+    x,
+    tx::NTuple,
+    contexts::Vararg{DI.Context,C};
+) where {F,C}
+    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
+    DI.value_and_pushforward!(f!, y, ty, prep, backend, x, tx, contexts...)
+    return ty
+end
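
The in-place `value_and_pushforward!` above converts each tangent to the type the engine expects, falling back to a preallocated buffer when the caller's array has a different type, and copies the result back only if a buffer was substituted. A plain-Julia sketch of that pattern follows; `to_righttype!`, `square_with_tangent!`, and `toy_pushforward!` are hypothetical helpers with a hand-written tangent rule, not Mooncake calls.

```julia
# Hypothetical helper: reuse `dx` when it already has the expected concrete type,
# otherwise copy it into a preallocated buffer of that type.
to_righttype!(buffer::Vector{Float64}, dx) =
    dx isa Vector{Float64} ? dx : copyto!(buffer, dx)

# Toy in-place function y .= x .^ 2 together with its hand-written tangent rule.
function square_with_tangent!(y, dy, x, dx)
    @. y = x^2
    @. dy = 2 * x * dx
    return nothing
end

function toy_pushforward!(y, dy, x, dx, dx_buffer, dy_buffer)
    dx_right = to_righttype!(dx_buffer, dx)
    dy_right = to_righttype!(dy_buffer, dy)
    square_with_tangent!(y, dy_right, x, dx_right)
    # Copy back only when a buffer stood in for the caller's array.
    dy === dy_right || copyto!(dy, dy_right)
    return y, dy
end

x = [1.0, 2.0, 3.0]
y = zeros(3)
dx = 1.0:3.0              # a range, not a Vector{Float64}
dy = view(zeros(4), 1:3)  # a view, not a Vector{Float64}
toy_pushforward!(y, dy, x, dx, zeros(3), zeros(3))
```

The identity check `dy === dy_right` skips the copy in the common case where the caller already passes arrays of the expected type.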

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@ using ADTypes:
     AutoForwardDiff,
     AutoGTPSA,
     AutoMooncake,
+    AutoMooncakeForward,
     AutoPolyesterForwardDiff,
     AutoReverseDiff,
     AutoSymbolics,
@@ -115,6 +116,7 @@ export AutoFiniteDifferences
 export AutoForwardDiff
 export AutoGTPSA
 export AutoMooncake
+export AutoMooncakeForward
 export AutoPolyesterForwardDiff
 export AutoReverseDiff
 export AutoSymbolics

DifferentiationInterface/src/misc/differentiate_with.jl

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be
 
 !!! warning
     `DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments.
-    It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), [Mooncake](https://github.com/chalk-lab/Mooncake.jl) or automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
+    It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), reverse-mode [Mooncake](https://github.com/chalk-lab/Mooncake.jl), or if it automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
     For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper).
 
 !!! warning
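
As a rough illustration of the wrapper described in this docstring, the sketch below differentiates a mutating function with Zygote as the true backend while ForwardDiff does the actual work inside `f`. It assumes DifferentiationInterface, ForwardDiff, and Zygote are installed; the function `f` is invented for the example, and the choice of backends is only one of the supported combinations listed above.

```julia
using DifferentiationInterface
import ForwardDiff, Zygote

# Illustrative function: the array mutation is something Zygote cannot differentiate.
function f(x)
    a = similar(x, 1)
    a[1] = sum(abs2, x)
    return a[1]
end

# Ask for `f` to be differentiated with ForwardDiff, whatever the outer backend is.
f2 = DifferentiateWith(f, AutoForwardDiff())

x = [1.0, 2.0, 3.0]
# Zygote is the true backend here; it should treat `f2` as a black box with a known rule.
grad = gradient(f2, AutoZygote(), x)
```

The gradient of `sum(abs2, x)` is `2x`, which gives a simple check on the result.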

DifferentiationInterface/test/Back/DifferentiateWith/test.jl

Lines changed: 3 additions & 1 deletion
@@ -62,7 +62,9 @@ end;
     @testset for scen in filter(differentiatewith_scenarios()) do scen
         DIT.operator(scen) == :pullback
     end
-        Mooncake.TestUtils.test_rule(StableRNG(0), scen.f, scen.x; is_primitive=true)
+        Mooncake.TestUtils.test_rule(
+            StableRNG(0), scen.f, scen.x; is_primitive=true, mode=Mooncake.ReverseMode
+        )
     end
 end;
 

DifferentiationInterface/test/Back/Mooncake/test.jl

Lines changed: 5 additions & 1 deletion
@@ -10,7 +10,11 @@ check_no_implicit_imports(DifferentiationInterface)
 
 LOGGING = get(ENV, "CI", "false") == "false"
 
-backends = [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())]
+backends = [
+    AutoMooncake(; config=nothing),
+    AutoMooncake(; config=Mooncake.Config()),
+    AutoMooncakeForward(; config=nothing),
+]
 
 for backend in backends
     @test check_available(backend)
