Skip to content

Commit 374bf32

Browse files
committed
Test DifferentiationInterface
1 parent c709fe1 commit 374bf32

File tree

4 files changed

+46
-1
lines changed

4 files changed

+46
-1
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1313
ThickNumbersForwardDiffExt = "ForwardDiff"
1414

1515
[compat]
16+
DifferentiationInterface = "0.6"
1617
ForwardDiff = "0.10, 1"
1718
LinearAlgebra = "1"
1819
julia = "1.9"
1920

2021
[extras]
22+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
2123
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2224
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2325

2426
[targets]
25-
test = ["ForwardDiff", "Test"]
27+
test = ["DifferentiationInterface", "ForwardDiff", "Test"]

ThickNumbersInterfaceTests/test/testpackages/IntervalArith/src/IntervalArith.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ end
1515
Interval(lo, hi) = Interval(promote(lo, hi)...)
1616
Interval{T}(iv::Interval) where T = Interval{T}(iv.lo, iv.hi)
1717
Interval{T}(x::Number) where T = Interval{T}(x, x)
18+
Interval{T}(nt::@NamedTuple{lo::T, hi::T}) where T = Interval{T}(nt.lo, nt.hi) # needed by Mooncake
1819

1920
ThickNumbers.loval(x::Interval) = x.lo
2021
ThickNumbers.hival(x::Interval) = x.hi
@@ -44,4 +45,6 @@ end
4445
Base.abs2(x::Interval) = Interval(mig(x)^2, mag(x)^2)
4546
Base.sqrt(x::Interval) = Interval(sqrt(loval(x)), sqrt(hival(x)))
4647

48+
Base.conj(x::Interval{T}) where T = Interval{T}(conj(x.lo), conj(x.hi)) # needed for Enzyme
49+
4750
end # module IntervalArith

test/extensions/di.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using ThickNumbers
2+
using DifferentiationInterface
3+
using ForwardDiff
4+
5+
using Test
6+
7+
include(joinpath(dirname(@__DIR__), "setpath.jl"))
8+
9+
using IntervalArith
10+
11+
@testset "DifferentiationInterface" begin
12+
@test isempty(detect_ambiguities(ThickNumbers))
13+
@test isempty(detect_ambiguities(IntervalArith))
14+
15+
for backend in (AutoForwardDiff(), )
16+
a, b = Interval(1, 2), Interval(0, 0.1)
17+
f1(t) = a + t*b
18+
f2(x) = a + abs2(x)/2
19+
20+
df1(t) = derivative(f1, backend, t)
21+
df2(x) = derivative(f2, backend, x)
22+
@test df1(0.5) b
23+
@test df2(b) b
24+
ddf2(x) = derivative(df2, backend, x)
25+
@test ddf2(b) 1
26+
27+
# abs
28+
dabs(x) = derivative(abs, backend, x)
29+
ddabs(x) = derivative(dabs, backend, x)
30+
dddabs(x) = derivative(ddabs, backend, x)
31+
@test dabs(Interval(1.0, 2.0)) === Interval(1.0, 1.0)
32+
@test ddabs(Interval(1.0, 2.0)) === Interval(0.0, 0.0)
33+
@test dddabs(Interval(1.0, 2.0)) === Interval(0.0, 0.0)
34+
@test dabs(Interval(-1.0, 2.0)) === Interval(-1.0, 1.0)
35+
@test ddabs(Interval(-1.0, 2.0)) === Interval(0.0, Inf)
36+
abs3 = dddabs(Interval(-1.0, 2.0))
37+
@test abs3 === Interval(-Inf, Inf) || isnan_tn(abs3)
38+
end
39+
end

test/extensions/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
include("forwarddiff.jl")
2+
include("di.jl")

0 commit comments

Comments
 (0)