Skip to content

Commit bf17102

Browse files
authored
fix: upgrade Mooncake compat to v0.5 (#961)
1 parent a266e80 commit bf17102

File tree

10 files changed

+93
-94
lines changed

10 files changed

+93
-94
lines changed

DifferentiationInterface/CHANGELOG.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8-
## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.15...main)
8+
## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.16...main)
9+
10+
## [0.7.16](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.15...DifferentiationInterface-v0.7.16)
11+
12+
### Fixed
13+
14+
- Upgrade Mooncake compat to v0.5 ([#961](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/961))
915

1016
## [0.7.15](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.14...DifferentiationInterface-v0.7.15)
1117

DifferentiationInterface/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.7.15"
4+
version = "0.7.16"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -69,7 +69,7 @@ ForwardDiff = "0.10.36,1"
6969
GPUArraysCore = "0.2"
7070
GTPSA = "1.4.0"
7171
LinearAlgebra = "1"
72-
Mooncake = "0.4.175"
72+
Mooncake = "0.5.1"
7373
PolyesterForwardDiff = "0.1.2"
7474
ReverseDiff = "1.15.1"
7575
SparseArrays = "1"

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ using Mooncake:
2929
NoRData,
3030
primal,
3131
_copy_output,
32-
_copy_to_output!!
32+
_copy_to_output!!,
33+
tangent_to_primal!!
3334

3435
const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}}
3536

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
## Pushforward
22

3-
struct MooncakeOneArgPushforwardPrep{SIG, Tcache, DX, FT, CT} <: DI.PushforwardPrep{SIG}
3+
struct MooncakeOneArgPushforwardPrep{SIG, Tcache, FT, CT} <: DI.PushforwardPrep{SIG}
44
_sig::Val{SIG}
55
cache::Tcache
6-
dx_righttype::DX
76
df::FT
87
context_tangents::CT
98
end
@@ -18,13 +17,10 @@ function DI.prepare_pushforward_nokwarg(
1817
) where {F, C}
1918
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
2019
config = get_config(backend)
21-
cache = prepare_derivative_cache(
22-
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
23-
)
24-
dx_righttype = zero_tangent(x)
25-
df = zero_tangent(f)
20+
cache = prepare_derivative_cache(f, x, map(DI.unwrap, contexts)...; config)
21+
df = zero_tangent_or_primal(f, backend)
2622
context_tangents = map(zero_tangent_unwrap, contexts)
27-
prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype, df, context_tangents)
23+
prep = MooncakeOneArgPushforwardPrep(_sig, cache, df, context_tangents)
2824
return prep
2925
end
3026

@@ -38,19 +34,17 @@ function DI.value_and_pushforward(
3834
) where {F, C, X}
3935
DI.check_prep(f, prep, backend, x, tx, contexts...)
4036
ys_and_ty = map(tx) do dx
41-
dx_righttype =
42-
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
43-
y_dual = value_and_derivative!!(
37+
y_and_dy = value_and_derivative!!(
4438
prep.cache,
45-
Dual(f, prep.df),
46-
Dual(x, dx_righttype),
47-
map(Dual_unwrap, contexts, prep.context_tangents)...,
39+
(f, prep.df),
40+
(x, dx),
41+
map(first_unwrap, contexts, prep.context_tangents)...,
4842
)
49-
y = primal(y_dual)
50-
dy = _copy_output(tangent(y_dual))
43+
y = first(y_and_dy)
44+
dy = _copy_output(last(y_and_dy))
5145
return y, dy
5246
end
53-
y = first(ys_and_ty[1])
47+
y = _copy_output(first(ys_and_ty[1]))
5448
ty = map(last, ys_and_ty)
5549
return y, ty
5650
end

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
## Pushforward
22

3-
struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, DX, DY, FT, CT} <: DI.PushforwardPrep{SIG}
3+
struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, FT0, FT, YT, CT} <: DI.PushforwardPrep{SIG}
44
_sig::Val{SIG}
55
cache::Tcache
6-
dx_righttype::DX
7-
dy_righttype::DY
6+
dcall::FT0
87
df!::FT
8+
dy::YT
99
context_tangents::CT
1010
end
1111

@@ -21,18 +21,18 @@ function DI.prepare_pushforward_nokwarg(
2121
_sig = DI.signature(f!, y, backend, x, tx, contexts...; strict)
2222
config = get_config(backend)
2323
cache = prepare_derivative_cache(
24+
call_and_return,
2425
f!,
2526
y,
2627
x,
2728
map(DI.unwrap, contexts)...;
28-
config.debug_mode,
29-
config.silence_debug_messages,
29+
config
3030
)
31-
dx_righttype = zero_tangent(x)
32-
dy_righttype = zero_tangent(y)
33-
df! = zero_tangent(f!)
31+
dcall = zero_tangent_or_primal(call_and_return, backend)
32+
df! = zero_tangent_or_primal(f!, backend)
33+
dy = zero_tangent_or_primal(y, backend)
3434
context_tangents = map(zero_tangent_unwrap, contexts)
35-
prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dx_righttype, dy_righttype, df!, context_tangents)
35+
prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dcall, df!, dy, context_tangents)
3636
return prep
3737
end
3838

@@ -47,18 +47,15 @@ function DI.value_and_pushforward(
4747
) where {F, C, X}
4848
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
4949
ty = map(tx) do dx
50-
dx_righttype =
51-
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
52-
y_dual = zero_dual(y)
53-
value_and_derivative!!(
50+
_, new_dy = value_and_derivative!!(
5451
prep.cache,
55-
Dual(f!, prep.df!),
56-
y_dual,
57-
Dual(x, dx_righttype),
58-
map(Dual_unwrap, contexts, prep.context_tangents)...,
52+
(call_and_return, prep.dcall),
53+
(f!, prep.df!),
54+
(y, prep.dy),
55+
(x, dx),
56+
map(first_unwrap, contexts, prep.context_tangents)...,
5957
)
60-
dy = _copy_output(tangent(y_dual))
61-
return dy
58+
return _copy_output(new_dy)
6259
end
6360
return y, ty
6461
end
@@ -88,18 +85,15 @@ function DI.value_and_pushforward!(
8885
) where {F, C, X, Y}
8986
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
9087
foreach(tx, ty) do dx, dy
91-
dx_righttype =
92-
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
93-
dy_righttype =
94-
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
95-
value_and_derivative!!(
88+
_, new_dy = value_and_derivative!!(
9689
prep.cache,
97-
Dual(f!, prep.df!),
98-
Dual(y, dy_righttype),
99-
Dual(x, dx_righttype),
100-
map(Dual_unwrap, contexts, prep.context_tangents)...,
90+
(call_and_return, prep.dcall),
91+
(f!, prep.df!),
92+
(y, dy),
93+
(x, dx),
94+
map(first_unwrap, contexts, prep.context_tangents)...,
10195
)
102-
dy === dy_righttype || copyto!(dy, dy_righttype)
96+
copyto!(dy, new_dy)
10397
end
10498
return y, ty
10599
end

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
## Pullback
22

3-
struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG}
3+
struct MooncakeOneArgPullbackPrep{SIG, Tcache, N} <: DI.PullbackPrep{SIG}
44
_sig::Val{SIG}
55
cache::Tcache
6-
dy_righttype::DY
76
args_to_zero::NTuple{N, Bool}
87
end
98

@@ -12,18 +11,14 @@ function DI.prepare_pullback_nokwarg(
1211
) where {F, C}
1312
_sig = DI.signature(f, backend, x, ty, contexts...; strict)
1413
config = get_config(backend)
15-
cache = prepare_pullback_cache(
16-
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
17-
)
18-
y = f(x, map(DI.unwrap, contexts)...)
19-
dy_righttype = zero_tangent(y)
14+
cache = prepare_pullback_cache(f, x, map(DI.unwrap, contexts)...; config)
2015
contexts_tup_false = map(_ -> false, contexts)
2116
args_to_zero = (
2217
false, # f
2318
true, # x
2419
contexts_tup_false...,
2520
)
26-
prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype, args_to_zero)
21+
prep = MooncakeOneArgPullbackPrep(_sig, cache, args_to_zero)
2722
return prep
2823
end
2924

@@ -37,10 +32,8 @@ function DI.value_and_pullback(
3732
) where {F, Y, C}
3833
DI.check_prep(f, prep, backend, x, ty, contexts...)
3934
dy = only(ty)
40-
dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
4135
new_y, (_, new_dx) = value_and_pullback!!(
42-
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...;
43-
prep.args_to_zero
36+
prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
4437
)
4538
return new_y, (_copy_output(new_dx),)
4639
end
@@ -55,11 +48,8 @@ function DI.value_and_pullback(
5548
) where {F, Y, C}
5649
DI.check_prep(f, prep, backend, x, ty, contexts...)
5750
ys_and_tx = map(ty) do dy
58-
dy_righttype =
59-
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
6051
y, (_, new_dx) = value_and_pullback!!(
61-
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...;
62-
prep.args_to_zero
52+
prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
6353
)
6454
y, _copy_output(new_dx)
6555
end
@@ -121,9 +111,7 @@ function DI.prepare_gradient_nokwarg(
121111
) where {F, C}
122112
_sig = DI.signature(f, backend, x, contexts...; strict)
123113
config = get_config(backend)
124-
cache = prepare_gradient_cache(
125-
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
126-
)
114+
cache = prepare_gradient_cache(f, x, map(DI.unwrap, contexts)...; config)
127115
contexts_tup_false = map(_ -> false, contexts)
128116
args_to_zero = (
129117
false, # f

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F, N} <: DI.PullbackPrep{SIG}
1+
struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG}
22
_sig::Val{SIG}
33
cache::Tcache
4-
dy_righttype::DY
5-
target_function::F
4+
dy_backup::DY
65
args_to_zero::NTuple{N, Bool}
76
end
87

@@ -16,31 +15,26 @@ function DI.prepare_pullback_nokwarg(
1615
contexts::Vararg{DI.Context, C}
1716
) where {F, C}
1817
_sig = DI.signature(f!, y, backend, x, ty, contexts...; strict)
19-
target_function = function (f!, y, x, contexts...)
20-
f!(y, x, contexts...)
21-
return y
22-
end
2318
config = get_config(backend)
2419
cache = prepare_pullback_cache(
25-
target_function,
20+
call_and_return,
2621
f!,
2722
y,
2823
x,
2924
map(DI.unwrap, contexts)...;
30-
debug_mode = config.debug_mode,
31-
silence_debug_messages = config.silence_debug_messages,
25+
config,
3226
)
33-
dy_righttype_after = zero_tangent(y)
27+
dy_backup = zero_tangent_or_primal(y, backend)
3428
contexts_tup_false = map(_ -> false, contexts)
3529
args_to_zero = (
36-
false, # target_function
30+
false, # call_and_return
3731
false, # f!
3832
false, # y
3933
true, # x
4034
contexts_tup_false...,
4135
)
4236
prep = MooncakeTwoArgPullbackPrep(
43-
_sig, cache, dy_righttype_after, target_function, args_to_zero
37+
_sig, cache, dy_backup, args_to_zero
4438
)
4539
return prep
4640
end
@@ -57,12 +51,12 @@ function DI.value_and_pullback(
5751
DI.check_prep(f!, y, prep, backend, x, ty, contexts...)
5852
dy = only(ty)
5953
# Copy the incoming cotangent `dy` into the buffer stored in `prep`, so that it can be added after the forward pass.
60-
dy_righttype_after = copyto!(prep.dy_righttype, dy)
54+
dy_backup = copyto!(prep.dy_backup, dy)
6155
# Run the reverse-pass and return the results.
6256
y_after, (_, _, _, dx) = value_and_pullback!!(
6357
prep.cache,
64-
dy_righttype_after,
65-
prep.target_function,
58+
dy_backup,
59+
call_and_return,
6660
f!,
6761
y,
6862
x,
@@ -84,11 +78,11 @@ function DI.value_and_pullback(
8478
) where {F, C}
8579
DI.check_prep(f!, y, prep, backend, x, ty, contexts...)
8680
tx = map(ty) do dy
87-
dy_righttype_after = copyto!(prep.dy_righttype, dy)
81+
dy_backup = copyto!(prep.dy_backup, dy)
8882
y_after, (_, _, _, dx) = value_and_pullback!!(
8983
prep.cache,
90-
dy_righttype_after,
91-
prep.target_function,
84+
dy_backup,
85+
call_and_return,
9286
f!,
9387
y,
9488
x,

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,18 @@ get_config(::AnyAutoMooncake{Nothing}) = Config()
22
get_config(backend::AnyAutoMooncake{<:Config}) = backend.config
33

44
@inline zero_tangent_unwrap(c::DI.Context) = zero_tangent(DI.unwrap(c))
5-
@inline Dual_unwrap(c, dc) = Dual(DI.unwrap(c), dc)
5+
@inline first_unwrap(c, dc) = (DI.unwrap(c), dc)
6+
7+
function call_and_return(f!::F, y, x, contexts...) where {F}
8+
f!(y, x, contexts...)
9+
return y
10+
end
11+
12+
function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
13+
if get_config(backend).friendly_tangents
14+
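# like zero(x), but obtained from zero_tangent(x) via tangent_to_primal!! on a copy of x, rather than by calling zero(x) itself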
15+
return tangent_to_primal!!(_copy_output(x), zero_tangent(x))
16+
else
17+
return zero_tangent(x)
18+
end
19+
end

DifferentiationInterface/test/Back/Mooncake/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
66
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
77
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
88
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
9+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
910
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

0 commit comments

Comments
 (0)