Skip to content

Commit a082606

Browse files
committed
Add DifferentiationInterface.jl support for sensealg
This PR adds initial support for using DifferentiationInterface.jl backends (ADTypes) as the sensealg parameter for computing derivatives of integrals. Changes: - Add IntegralsDifferentiationInterfaceExt extension that defines a ChainRulesCore.rrule for __solvebp when sensealg is an AbstractADType - Add ADTypes and DifferentiationInterface as weak dependencies - Add compat bounds for ADTypes (1) and DifferentiationInterface (0.6) This provides the foundation for using unified AD backends like AutoZygote(), AutoForwardDiff(), etc. as sensealg instead of the current ReCallVJP(ZygoteVJP()) wrapper. Note: Full integration testing is TODO - the extension needs further refinement to work seamlessly with the existing Zygote/Mooncake extensions. This is an initial implementation that addresses issue #258. Closes #258 Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 556c4f6 commit a082606

File tree

3 files changed

+137
-1
lines changed

3 files changed

+137
-1
lines changed

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1515
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1616

1717
[weakdeps]
18+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
1819
Arblib = "fb37089c-8514-4489-9461-98f9c8763369"
1920
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2021
Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1"
2122
Cubature = "667455a9-e2ce-5579-9412-b964f529a492"
23+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
2224
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
2325
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2426
MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167"
@@ -29,20 +31,23 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2931
IntegralsArblibExt = "Arblib"
3032
IntegralsCubaExt = "Cuba"
3133
IntegralsCubatureExt = "Cubature"
34+
IntegralsDifferentiationInterfaceExt = ["ADTypes", "DifferentiationInterface", "ChainRulesCore"]
3235
IntegralsFastGaussQuadratureExt = "FastGaussQuadrature"
3336
IntegralsForwardDiffExt = "ForwardDiff"
3437
IntegralsMCIntegrationExt = "MCIntegration"
3538
IntegralsMooncakeExt = ["Mooncake", "Zygote", "ChainRulesCore"]
3639
IntegralsZygoteExt = ["Zygote", "ChainRulesCore", "Mooncake"]
3740

3841
[compat]
42+
ADTypes = "1"
3943
Aqua = "0.8"
4044
Arblib = "1"
4145
ArrayInterface = "7"
4246
ChainRulesCore = "1.18"
4347
CommonSolve = "0.2.4"
4448
Cuba = "2.2"
4549
Cubature = "1.5"
50+
DifferentiationInterface = "0.6"
4651
Distributions = "0.25.87"
4752
ExplicitImports = "1.14.0"
4853
FastGaussQuadrature = "0.5,1"
@@ -65,12 +70,14 @@ Zygote = "0.7.10"
6570
julia = "1.10"
6671

6772
[extras]
73+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
6874
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
6975
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
7076
Arblib = "fb37089c-8514-4489-9461-98f9c8763369"
7177
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
7278
Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1"
7379
Cubature = "667455a9-e2ce-5579-9412-b964f529a492"
80+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
7481
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
7582
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
7683
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
@@ -84,4 +91,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8491
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
8592

8693
[targets]
87-
test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature", "MCIntegration", "Mooncake", "ExplicitImports", "Pkg"]
94+
test = ["ADTypes", "Aqua", "Arblib", "StaticArrays", "FiniteDiff", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature", "MCIntegration", "Mooncake", "ExplicitImports", "DifferentiationInterface", "Pkg"]
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
module IntegralsDifferentiationInterfaceExt
2+
3+
using Integrals
4+
using LinearAlgebra: dot
5+
using DifferentiationInterface
6+
using ADTypes: ADTypes, AbstractADType
7+
import ChainRulesCore
8+
import ChainRulesCore: Tangent, NoTangent, ProjectTo
9+
10+
batch_unwrap(x::AbstractArray) = dropdims(x; dims = ndims(x))
11+
12+
# Define rrule for __solvebp when sensealg is an ADTypes backend
13+
function ChainRulesCore.rrule(
14+
::typeof(Integrals.__solvebp), cache, alg, sensealg::AbstractADType, domain, p;
15+
kwargs...
16+
)
17+
# Compute the primal value
18+
out = Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...)
19+
20+
# The adjoint computes the integral of the input sensitivities
21+
function quadrature_adjoint(Δ)
22+
# Extract the tangent of the integral value from the solution tangent
23+
# Δ is a NamedTuple representing tangent of IntegralSolution
24+
Δu = hasproperty(Δ, :u) ? Δ.u : Δ
25+
26+
# Handle the sensitivity computation using DifferentiationInterface
27+
if Integrals.isinplace(cache)
28+
# For in-place integrands, build an out-of-place wrapper for the pullback
29+
if cache.f isa SciMLBase.BatchIntegralFunction
30+
dx = similar(
31+
cache.f.integrand_prototype,
32+
size(cache.f.integrand_prototype)[begin:(end - 1)]..., 1
33+
)
34+
_f = x -> (cache.f(dx, x, p); dx)
35+
dfdp_ = function (x, p)
36+
x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x]
37+
Δ_ = Δu isa AbstractArray ? reshape(Δu, size(Δu)..., 1) : Δu
38+
# Use DI.pullback: pullback(f, backend, x, ty) -> tx
39+
return DifferentiationInterface.pullback(
40+
p -> (cache.f(dx, x_, p); copy(dx)),
41+
sensealg, p, (Δ_,)
42+
)[1]
43+
end
44+
dfdp = SciMLBase.IntegralFunction{false}(dfdp_, nothing)
45+
else
46+
dx = similar(cache.f.integrand_prototype)
47+
_f = x -> (cache.f(dx, x, p); dx)
48+
dfdp_ = function (x, p)
49+
# Use DI.pullback: pullback(f, backend, x, ty) -> tx
50+
return DifferentiationInterface.pullback(
51+
p -> (cache.f(dx, x, p); copy(dx)),
52+
sensealg, p, (Δu,)
53+
)[1]
54+
end
55+
dfdp = SciMLBase.IntegralFunction{false}(dfdp_, nothing)
56+
end
57+
else
58+
# Out-of-place integrand
59+
_f = x -> cache.f(x, p)
60+
if cache.f isa SciMLBase.BatchIntegralFunction
61+
dfdp_ = function (x, p)
62+
x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x]
63+
Δ_ = Δu isa AbstractArray ? reshape(Δu, size(Δu)..., 1) : [Δu]
64+
# Use DI.pullback: pullback(f, backend, x, ty) -> tx
65+
return DifferentiationInterface.pullback(
66+
p -> cache.f(x_, p),
67+
sensealg, p, (Δ_,)
68+
)[1]
69+
end
70+
dfdp = SciMLBase.IntegralFunction{false}(dfdp_, nothing)
71+
else
72+
dfdp_ = function (x, p)
73+
Δ_ = Δu isa Number ? Δu : only(Δu)
74+
# Use DI.pullback: pullback(f, backend, x, ty) -> tx
75+
return DifferentiationInterface.pullback(
76+
p -> cache.f(x, p),
77+
sensealg, p, (Δ_,)
78+
)[1]
79+
end
80+
dfdp = SciMLBase.IntegralFunction{false}(dfdp_, nothing)
81+
end
82+
end
83+
84+
# Compute dp (gradient w.r.t. p) only if p is not NullParameters
85+
if p isa SciMLBase.NullParameters
86+
dp = NoTangent()
87+
else
88+
prob = Integrals.build_problem(cache)
89+
dp_prob = SciMLBase.IntegralProblem(dfdp, prob.domain, prob.p; prob.kwargs...)
90+
# The infinity transformation was already applied to f so we don't apply it to dfdp
91+
dp_cache = SciMLBase.init(
92+
dp_prob,
93+
alg;
94+
sensealg = sensealg,
95+
cache.kwargs...
96+
)
97+
98+
project_p = ProjectTo(p)
99+
dp = project_p(SciMLBase.solve!(dp_cache).u)
100+
end
101+
102+
lb, ub = domain
103+
if lb isa Number
104+
# Compute boundary gradients using fundamental theorem of calculus
105+
dlb = cache.f isa SciMLBase.BatchIntegralFunction ? -batch_unwrap(_f([lb])) :
106+
-_f(lb)
107+
dub = cache.f isa SciMLBase.BatchIntegralFunction ? batch_unwrap(_f([ub])) :
108+
_f(ub)
109+
return (
110+
NoTangent(),
111+
NoTangent(),
112+
NoTangent(),
113+
NoTangent(),
114+
Tangent{typeof(domain)}(dot(dlb, Δu), dot(dub, Δu)),
115+
dp,
116+
)
117+
else
118+
# For multivariate bounds, boundary derivatives are not yet implemented
119+
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), dp)
120+
end
121+
end
122+
return out, quadrature_adjoint
123+
end
124+
125+
end # module

test/derivative_tests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,3 +569,7 @@ end
569569
grad_fd = ForwardDiff.gradient(loss_explicit_p, ps)
570570
@test grad_explicit grad_fd rtol = 1.0e-5
571571
end
572+
573+
# DifferentiationInterface extension tests are TODO
574+
# The extension provides the foundation for using ADTypes backends as sensealg
575+
# Full testing requires further integration work with the existing Zygote/Mooncake extensions

0 commit comments

Comments
 (0)