diff --git a/Project.toml b/Project.toml index d65787be7..af644251a 100644 --- a/Project.toml +++ b/Project.toml @@ -37,6 +37,7 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" [weakdeps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -55,6 +56,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] SciMLBaseChainRulesCoreExt = "ChainRulesCore" +SciMLBaseDifferentiationInterfaceExt = "DifferentiationInterface" SciMLBaseDistributionsExt = "Distributions" SciMLBaseEnzymeExt = "Enzyme" SciMLBaseForwardDiffExt = "ForwardDiff" @@ -80,6 +82,7 @@ ChainRules = "1.58.0" ChainRulesCore = "1.18" CommonSolve = "0.2.4" ConstructionBase = "1.5" +DifferentiationInterface = "0.6, 0.7" Distributed = "1.10" Distributions = "0.25" DocStringExtensions = "0.9" diff --git a/ext/SciMLBaseDifferentiationInterfaceExt.jl b/ext/SciMLBaseDifferentiationInterfaceExt.jl new file mode 100644 index 000000000..cd491ca11 --- /dev/null +++ b/ext/SciMLBaseDifferentiationInterfaceExt.jl @@ -0,0 +1,18 @@ +module SciMLBaseDifferentiationInterfaceExt + +using SciMLBase, DifferentiationInterface + +import SciMLBase: anyeltypedual + +# Opt out since these are using for preallocation, not differentiation +function anyeltypedual(x::DifferentiationInterface.Prep, + ::Type{Val{counter}} = Val{0}) where {counter} + Any +end +function anyeltypedual(x::Type{T}, + ::Type{Val{counter}} = Val{0}) where {counter} where {T <: + DifferentiationInterface.Prep} + Any +end + +end