diff --git a/Project.toml b/Project.toml index 623193e..d1cdc2c 100644 --- a/Project.toml +++ b/Project.toml @@ -13,13 +13,15 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ThickNumbersForwardDiffExt = "ForwardDiff" [compat] +DifferentiationInterface = "0.6" ForwardDiff = "0.10, 1" LinearAlgebra = "1" julia = "1.9" [extras] +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ForwardDiff", "Test"] +test = ["DifferentiationInterface", "ForwardDiff", "Test"] diff --git a/ThickNumbersInterfaceTests/test/testpackages/IntervalArith/src/IntervalArith.jl b/ThickNumbersInterfaceTests/test/testpackages/IntervalArith/src/IntervalArith.jl index c7cd510..d32cc9e 100644 --- a/ThickNumbersInterfaceTests/test/testpackages/IntervalArith/src/IntervalArith.jl +++ b/ThickNumbersInterfaceTests/test/testpackages/IntervalArith/src/IntervalArith.jl @@ -44,4 +44,6 @@ end Base.abs2(x::Interval) = Interval(mig(x)^2, mag(x)^2) Base.sqrt(x::Interval) = Interval(sqrt(loval(x)), sqrt(hival(x))) +Base.conj(x::Interval{T}) where T = Interval{T}(conj(x.lo), conj(x.hi)) # needed for Enzyme + end # module IntervalArith diff --git a/test/extensions/di.jl b/test/extensions/di.jl new file mode 100644 index 0000000..f059037 --- /dev/null +++ b/test/extensions/di.jl @@ -0,0 +1,39 @@ +using ThickNumbers +using DifferentiationInterface +using ForwardDiff + +using Test + +include(joinpath(dirname(@__DIR__), "setpath.jl")) + +using IntervalArith + +@testset "DifferentiationInterface" begin + @test isempty(detect_ambiguities(ThickNumbers)) + @test isempty(detect_ambiguities(IntervalArith)) + + for backend in (AutoForwardDiff(), ) + a, b = Interval(1, 2), Interval(0, 0.1) + f1(t) = a + t*b + f2(x) = a + abs2(x)/2 + + df1(t) = derivative(f1, backend, t) + df2(x) = derivative(f2, backend, x) + @test df1(0.5) ≐ b + @test df2(b) ⩪ b + ddf2(x) = derivative(df2, backend, x) + @test ddf2(b) ≐ 1 + + # abs + dabs(x) = derivative(abs, backend, x) + ddabs(x) = derivative(dabs, backend, x) + dddabs(x) = derivative(ddabs, backend, x) + @test dabs(Interval(1.0, 2.0)) === Interval(1.0, 1.0) + @test ddabs(Interval(1.0, 2.0)) === Interval(0.0, 0.0) + @test dddabs(Interval(1.0, 2.0)) === Interval(0.0, 0.0) + @test dabs(Interval(-1.0, 2.0)) === Interval(-1.0, 1.0) + @test ddabs(Interval(-1.0, 2.0)) === Interval(0.0, Inf) + abs3 = dddabs(Interval(-1.0, 2.0)) + @test abs3 === Interval(-Inf, Inf) || isnan_tn(abs3) + end +end diff --git a/test/extensions/runtests.jl b/test/extensions/runtests.jl index 07629ba..0305422 100644 --- a/test/extensions/runtests.jl +++ b/test/extensions/runtests.jl @@ -1 +1,2 @@ include("forwarddiff.jl") +include("di.jl")