diff --git a/Project.toml b/Project.toml index d1cdc2c..55acbb2 100644 --- a/Project.toml +++ b/Project.toml @@ -8,20 +8,24 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] ThickNumbersForwardDiffExt = "ForwardDiff" +ThickNumbersMooncakeExt = "Mooncake" [compat] DifferentiationInterface = "0.6" ForwardDiff = "0.10, 1" LinearAlgebra = "1" +Mooncake = "0.4" julia = "1.9" [extras] DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["DifferentiationInterface", "ForwardDiff", "Test"] +test = ["DifferentiationInterface", "ForwardDiff", "Mooncake", "Test"] diff --git a/ThickNumbersInterfaceTests/test/testpackages/IntervalArith/src/IntervalArith.jl b/ThickNumbersInterfaceTests/test/testpackages/IntervalArith/src/IntervalArith.jl index d32cc9e..2139a33 100644 --- a/ThickNumbersInterfaceTests/test/testpackages/IntervalArith/src/IntervalArith.jl +++ b/ThickNumbersInterfaceTests/test/testpackages/IntervalArith/src/IntervalArith.jl @@ -15,6 +15,7 @@ end Interval(lo, hi) = Interval(promote(lo, hi)...) Interval{T}(iv::Interval) where T = Interval{T}(iv.lo, iv.hi) Interval{T}(x::Number) where T = Interval{T}(x, x) +Interval{T}(nt::@NamedTuple{lo::T, hi::T}) where T = Interval{T}(nt.lo, nt.hi) # needed by Mooncake ThickNumbers.loval(x::Interval) = x.lo ThickNumbers.hival(x::Interval) = x.hi diff --git a/ext/ThickNumbersMooncakeExt.jl b/ext/ThickNumbersMooncakeExt.jl new file mode 100644 index 0000000..1599e15 --- /dev/null +++ b/ext/ThickNumbersMooncakeExt.jl @@ -0,0 +1,12 @@ +module ThickNumbersMooncakeExt + +using ThickNumbers +using Mooncake + +Mooncake.tangent_type(::Type{TN}) where TN<:ThickNumber = TN +Mooncake.fdata_type(::Type{TN}) where TN<:ThickNumber = Mooncake.NoFData +Mooncake.rdata_type(::Type{TN}) where TN<:ThickNumber = TN +Mooncake.zero_rdata(x::ThickNumber) = zero(x) +Mooncake.increment_internal!!(::Mooncake.IncCache, x::TN, y::TN) where TN<:ThickNumber = x + y + +end diff --git a/test/extensions/di.jl b/test/extensions/di.jl index f059037..04e158b 100644 --- a/test/extensions/di.jl +++ b/test/extensions/di.jl @@ -1,6 +1,8 @@ using ThickNumbers using DifferentiationInterface using ForwardDiff +# using Enzyme: EnzymeCore +using Mooncake using Test @@ -12,7 +14,7 @@ using IntervalArith @test isempty(detect_ambiguities(ThickNumbers)) @test isempty(detect_ambiguities(IntervalArith)) - for backend in (AutoForwardDiff(), ) + for backend in (AutoForwardDiff(), #=AutoEnzyme(mode=EnzymeCore.Forward), AutoEnzyme(mode=EnzymeCore.Reverse),=# AutoMooncake(config=nothing)) a, b = Interval(1, 2), Interval(0, 0.1) f1(t) = a + t*b f2(x) = a + abs2(x)/2