Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ RecursiveArrayTools = "3.27.2"
Reexport = "1.0"
ReverseDiff = "1.15.1"
SafeTestsets = "0.1.0"
SciMLBase = "2.103.1"
SciMLBase = "2.117.0"
SciMLJacobianOperators = "0.1"
SciMLStructures = "1.3"
SparseArrays = "1.10"
Expand Down
1 change: 1 addition & 0 deletions src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ include("concrete_solve.jl")
include("second_order.jl")
include("steadystate_adjoint.jl")
include("sde_tools.jl")
include("enzyme_rules.jl")

export extract_local_sensitivities

Expand Down
14 changes: 14 additions & 0 deletions src/enzyme_rules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Enzyme rules for VJP choice types defined in SciMLSensitivity
#
# VJP choice types configure how jacobian-vector products are computed within
# sensitivity algorithms. They should be treated as inactive (constant) during
# Enzyme differentiation to prevent errors when they are stored in problem
# structures or other data that Enzyme differentiates through.
#
# Note: AbstractSensitivityAlgorithm inactive rule is handled in SciMLBase
# to avoid type piracy.

import Enzyme: EnzymeRules

# VJP choice types should be inactive since they configure computation methods
EnzymeRules.inactive_type(::Type{<:VJPChoice}) = true
72 changes: 72 additions & 0 deletions test/enzyme_vjp_inactive.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using Test, SciMLSensitivity, Enzyme, OrdinaryDiffEq

# Test that VJP choice types are treated as inactive by Enzyme
# The AbstractSensitivityAlgorithm inactive rule is handled in SciMLBase
# This addresses issue #1225 where sensealg in ODEProblem constructor would fail

@testset "Enzyme VJP Choice Inactive Types" begin

# Test 1: Basic test that VJP objects can be stored in data structures during Enzyme differentiation
@testset "VJP types in data structures" begin
vjp = EnzymeVJP()

function test_func(x)
# Store the VJP in a data structure (this would fail without inactive rules)
data = (value = x[1] + x[2], vjp = vjp)
return data.value * 2.0
end

x = [1.0, 2.0]
dx = Enzyme.make_zero(x)

# This should not throw an error
@test_nowarn Enzyme.autodiff(Enzyme.Reverse, test_func, Enzyme.Active, Enzyme.Duplicated(x, dx))
@test dx ≈ [2.0, 2.0]
end

# Test 2: Test different VJP choice types are inactive
@testset "Different VJP types inactive" begin
vjp_types = [EnzymeVJP(), ZygoteVJP(), ReverseDiffVJP(), TrackerVJP()]

for vjp in vjp_types
function test_func(x)
data = (value = x[1] * x[2], vjp = vjp)
return data.value + 1.0
end

x = [2.0, 3.0]
dx = Enzyme.make_zero(x)

@test_nowarn Enzyme.autodiff(Enzyme.Reverse, test_func, Enzyme.Active, Enzyme.Duplicated(x, dx))
end
end

# Test 3: Test sensitivity algorithms with VJP choices (integration test)
# Note: This test also depends on SciMLBase having AbstractSensitivityAlgorithm as inactive
@testset "Sensitivity algorithms with VJP choices" begin
function f(du, u, p, t)
du[1] = -p[1] * u[1]
du[2] = p[2] * u[2]
end

function loss_func(p)
u0 = [1.0, 2.0]
# Both VJP choice and sensitivity algorithm should be inactive
prob = ODEProblem(
f, u0, (0.0, 0.1), p, sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()))
sol = solve(prob, Tsit5())
return sol.u[end][1] + sol.u[end][2]
end

p = [0.5, 1.5]
dp = Enzyme.make_zero(p)

# This should not throw the "Error handling recursive stores for String" error
# This is the original failing case from issue #1225
@test_nowarn Enzyme.autodiff(Enzyme.Reverse, loss_func, Enzyme.Active, Enzyme.Duplicated(p, dp))

# Verify the gradient is computed (non-zero and finite)
@test all(isfinite, dp)
@test any(x -> abs(x) > 1e-10, dp) # At least one component should be non-trivial
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ end
@time @safetestset "Scalar u0" include("scalar_u.jl")
@time @safetestset "Error Messages" include("error_messages.jl")
@time @safetestset "Autodiff Events" include("autodiff_events.jl")
@time @safetestset "Enzyme VJP Inactive" include("enzyme_vjp_inactive.jl")
end
end

Expand Down
Loading