Skip to content

Commit 7655631

Browse files
committed
Test DI, support Mooncake
1 parent c709fe1 commit 7655631

File tree

5 files changed

+64
-1
lines changed

5 files changed

+64
-1
lines changed

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,24 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88

99
[weakdeps]
1010
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
11+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
1112

1213
[extensions]
1314
ThickNumbersForwardDiffExt = "ForwardDiff"
15+
ThickNumbersMooncakeExt = "Mooncake"
1416

1517
[compat]
18+
DifferentiationInterface = "0.6"
1619
ForwardDiff = "0.10, 1"
1720
LinearAlgebra = "1"
21+
Mooncake = "0.4"
1822
julia = "1.9"
1923

2024
[extras]
25+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
2126
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
27+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2228
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2329

2430
[targets]
25-
test = ["ForwardDiff", "Test"]
31+
test = ["DifferentiationInterface", "ForwardDiff", "Mooncake", "Test"]

ThickNumbersInterfaceTests/test/testpackages/IntervalArith/src/IntervalArith.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ end
1515
Interval(lo, hi) = Interval(promote(lo, hi)...)
1616
Interval{T}(iv::Interval) where T = Interval{T}(iv.lo, iv.hi)
1717
Interval{T}(x::Number) where T = Interval{T}(x, x)
18+
Interval{T}(nt::@NamedTuple{lo::T, hi::T}) where T = Interval{T}(nt.lo, nt.hi) # needed by Mooncake
1819

1920
ThickNumbers.loval(x::Interval) = x.lo
2021
ThickNumbers.hival(x::Interval) = x.hi
@@ -44,4 +45,6 @@ end
4445
Base.abs2(x::Interval) = Interval(mig(x)^2, mag(x)^2)
4546
Base.sqrt(x::Interval) = Interval(sqrt(loval(x)), sqrt(hival(x)))
4647

48+
Base.conj(x::Interval{T}) where T = Interval{T}(conj(x.lo), conj(x.hi)) # needed for Enzyme
49+
4750
end # module IntervalArith

ext/ThickNumbersMooncakeExt.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module ThickNumbersMooncakeExt
2+
3+
using ThickNumbers
4+
using Mooncake
5+
6+
Mooncake.tangent_type(::Type{TN}) where TN<:ThickNumber = TN
7+
Mooncake.fdata_type(::Type{TN}) where TN<:ThickNumber = Mooncake.NoFData
8+
Mooncake.rdata_type(::Type{TN}) where TN<:ThickNumber = TN
9+
Mooncake.zero_rdata(x::ThickNumber) = zero(x)
10+
Mooncake.increment_internal!!(::Mooncake.IncCache, x::TN, y::TN) where TN<:ThickNumber = x + y
11+
12+
end

test/extensions/di.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
using ThickNumbers
2+
using DifferentiationInterface
3+
using ForwardDiff
4+
# using Enzyme: EnzymeCore
5+
using Mooncake
6+
7+
using Test
8+
9+
include(joinpath(dirname(@__DIR__), "setpath.jl"))
10+
11+
using IntervalArith
12+
13+
@testset "DifferentiationInterface" begin
14+
@test isempty(detect_ambiguities(ThickNumbers))
15+
@test isempty(detect_ambiguities(IntervalArith))
16+
17+
for backend in (AutoForwardDiff(), #=AutoEnzyme(mode=EnzymeCore.Forward), AutoEnzyme(mode=EnzymeCore.Reverse),=# AutoMooncake(config=nothing))
18+
a, b = Interval(1, 2), Interval(0, 0.1)
19+
f1(t) = a + t*b
20+
f2(x) = a + abs2(x)/2
21+
22+
df1(t) = derivative(f1, backend, t)
23+
df2(x) = derivative(f2, backend, x)
24+
@test df1(0.5) b
25+
@test df2(b) b
26+
ddf2(x) = derivative(df2, backend, x)
27+
@test ddf2(b) 1
28+
29+
# abs
30+
dabs(x) = derivative(abs, backend, x)
31+
ddabs(x) = derivative(dabs, backend, x)
32+
dddabs(x) = derivative(ddabs, backend, x)
33+
@test dabs(Interval(1.0, 2.0)) === Interval(1.0, 1.0)
34+
@test ddabs(Interval(1.0, 2.0)) === Interval(0.0, 0.0)
35+
@test dddabs(Interval(1.0, 2.0)) === Interval(0.0, 0.0)
36+
@test dabs(Interval(-1.0, 2.0)) === Interval(-1.0, 1.0)
37+
@test ddabs(Interval(-1.0, 2.0)) === Interval(0.0, Inf)
38+
abs3 = dddabs(Interval(-1.0, 2.0))
39+
@test abs3 === Interval(-Inf, Inf) || isnan_tn(abs3)
40+
end
41+
end

test/extensions/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
include("forwarddiff.jl")
2+
include("di.jl")

0 commit comments

Comments
 (0)