5 changes: 4 additions & 1 deletion Project.toml
@@ -4,23 +4,26 @@ authors = ["Vedant Puri <[email protected]>"]
 version = "1.6.0"
 
 [deps]
+Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
-Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 
 [weakdeps]
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
 
 [extensions]
+SciMLOperatorsChainRulesCoreExt = "ChainRulesCore"
 SciMLOperatorsSparseArraysExt = "SparseArrays"
 SciMLOperatorsStaticArraysCoreExt = "StaticArraysCore"
 
 [compat]
 Accessors = "0.1.42"
 ArrayInterface = "7.19"
+ChainRulesCore = "1.26.0"
 DocStringExtensions = "0.9.4"
 LinearAlgebra = "1.10"
 MacroTools = "0.5.16"
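Note: the [weakdeps]/[extensions] entries above register the new rrule as a package extension, so ChainRulesCore stays optional and the extension module only loads once ChainRulesCore is in the environment. A minimal way to verify this from the REPL (a sketch, assuming Julia >= 1.9, where package extensions and Base.get_extension exist):

    using SciMLOperators
    Base.get_extension(SciMLOperators, :SciMLOperatorsChainRulesCoreExt)  # nothing: extension not loaded yet
    using ChainRulesCore
    Base.get_extension(SciMLOperators, :SciMLOperatorsChainRulesCoreExt)  # now returns the extension module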
53 changes: 53 additions & 0 deletions ext/SciMLOperatorsChainRulesCoreExt.jl
@@ -0,0 +1,53 @@
module SciMLOperatorsChainRulesCoreExt

using SciMLOperators
using ChainRulesCore
import SciMLOperators: ScaledOperator, ScalarOperator, AbstractSciMLOperator

"""
Fix for gradient double-counting issue in ScaledOperator constructor.

The issue: When creating ScaledOperator(λ, L) where λ is a ScalarOperator with parameter
dependencies, Zygote was double-counting gradients because:
1. Gradient flows through the ScalarOperator's creation/value
2. Gradient also flows through the ScalarOperator being stored as a struct field

This rrule ensures gradients are only counted once by carefully managing the pullback
to avoid the structural dependency double-counting.

Fixes issue: https://github.com/SciML/SciMLOperators.jl/issues/305
"""
function ChainRulesCore.rrule(
        ::Type{ScaledOperator}, λ::ScalarOperator, L::AbstractSciMLOperator)
    # Forward pass - same as the original constructor
    result = ScaledOperator(λ, L)

    function ScaledOperator_pullback(Ȳ)
        # Handle gradients carefully to avoid double-counting for ScalarOperator.
        # The key insight: gradients should flow through the ScalarOperator's creation,
        # but NOT through struct field access.
        if hasfield(typeof(Ȳ), :λ) && getfield(Ȳ, :λ) isa ChainRulesCore.AbstractTangent
            λ_tangent = getfield(Ȳ, :λ)
            # For ScalarOperator, only propagate through the value to avoid double-counting
            if hasfield(typeof(λ_tangent), :val)
                ∂λ = ChainRulesCore.Tangent{typeof(λ)}(val = getfield(λ_tangent, :val))
            else
                ∂λ = λ_tangent
            end
        else
            ∂λ = NoTangent()
        end

        if hasfield(typeof(Ȳ), :L) && getfield(Ȳ, :L) isa ChainRulesCore.AbstractTangent
            ∂L = getfield(Ȳ, :L)
        else
            ∂L = NoTangent()
        end

        return (NoTangent(), ∂λ, ∂L)
    end

    return result, ScaledOperator_pullback
end

end # module
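As a quick illustration of what this rrule guards against, consider the pattern exercised by the tests below (a minimal sketch; the function f and the constants are illustrative only). Before this fix, Zygote returned twice the true gradient because the ScalarOperator contribution was counted both at creation and again through the stored struct field:

    using SciMLOperators, Zygote

    M = MatrixOperator(ones(2, 2))
    f(p) = (ScalarOperator(2.0 * p) * M).λ.val  # λ * L constructs a ScaledOperator
    Zygote.gradient(f, 0.5)                     # expected (2.0,); the double-counted result was (4.0,)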
119 changes: 119 additions & 0 deletions test/chainrules.jl
@@ -0,0 +1,119 @@
# Tests for ChainRules extension fixing gradient double-counting issue
# These tests specifically target issue #305

using SciMLOperators
using LinearSolve, Zygote, Test
using SciMLOperators: ScaledOperator

@testset "ChainRules fix for ScalarOperator gradient double-counting" begin
# Test 1: Simple ScaledOperator creation
@testset "Simple ScaledOperator gradient" begin
simple_func = p -> 2.0 * p

# Create ScalarOperator and matrix operator
S = ScalarOperator(0.0, (A, u, p, t) -> simple_func(p))
M = MatrixOperator(ones(2, 2))

# Test that ScaledOperator creation doesn't double-count gradients
function test_scaled(p)
S_val = ScalarOperator(simple_func(p))
scaled = S_val * M
return scaled.λ.val
end

p_val = 0.5
result = test_scaled(p_val)
grad = Zygote.gradient(test_scaled, p_val)[1]

@test result ≈ simple_func(p_val)
@test grad ≈ 2.0 # Should not be doubled (4.0)
end

# Test 2: Full update_coefficients pipeline
@testset "update_coefficients pipeline" begin
exp_func = p -> exp(1 - p)

A1 = MatrixOperator(rand(3, 3))
A2 = MatrixOperator(rand(3, 3))
Func = ScalarOperator(0.0, (A, u, p, t) -> exp_func(p))
A = A1 + Func * A2

# Test that update_coefficients doesn't cause gradient doubling
function test_update(p)
A_updated = update_coefficients(A, 0, p, 0)
# Access the scalar value from the updated composition
scaled_op = A_updated.ops[2] # This should be the ScaledOperator
return scaled_op.λ.val
end

p_val = 0.3
result = test_update(p_val)
grad = Zygote.gradient(test_update, p_val)[1]

@test result ≈ exp_func(p_val)
# Check that gradient matches the derivative of exp_func
expected_grad = -exp(1 - p_val) # derivative of exp(1-p) is -exp(1-p)
@test grad ≈ expected_grad
end

# Test 3: Original MWE from issue #305
@testset "Original MWE from issue #305" begin
a1 = rand(3, 3)
a2 = rand(3, 3)
func = p -> exp(1 - p)
a = p -> a1 + func(p) * a2

A1 = MatrixOperator(a1)
A2 = MatrixOperator(a2)
Func = ScalarOperator(0.0, (A, u, p, t) -> func(p))
A = A1 + Func * A2

b = rand(3)

function sol1(p)
Ap = update_coefficients(A, 0, p, 0) |> concretize
prob = LinearProblem(Ap, b)
sol = solve(prob, KrylovJL_GMRES())
return sum(sol.u)
end

function sol2(p)
Ap = a(p)
prob = LinearProblem(Ap, b)
sol = solve(prob, KrylovJL_GMRES())
return sum(sol.u)
end

p_val = rand()
s1, s2 = sol1(p_val), sol2(p_val)

# Primal solutions should match
@test s1 ≈ s2

grad1 = Zygote.gradient(sol1, p_val)[1]
grad2 = Zygote.gradient(sol2, p_val)[1]

# Gradients should match (no more doubling)
@test grad1 ≈ grad2 rtol=1e-10
@test !(grad1 ≈ 2 * grad2) # Should NOT be doubled anymore
end

# Test 4: Direct ScaledOperator constructor (the specific case our rrule fixes)
@testset "Direct ScaledOperator constructor" begin
func = p -> 3.0 * p

function test_direct_constructor(p)
S = ScalarOperator(func(p))
M = MatrixOperator([2.0 1.0; 1.0 2.0])
scaled = ScaledOperator(S, M) # This should use our rrule
return scaled.λ.val
end

p_val = 0.5
result = test_direct_constructor(p_val)
grad = Zygote.gradient(test_direct_constructor, p_val)[1]

@test result ≈ func(p_val)
@test grad ≈ 3.0 # Should not be doubled (6.0)
end
end
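Note: for these tests to run in CI, test/runtests.jl would need to include this file, and LinearSolve and Zygote would need to be test dependencies; neither change appears in this diff. A hypothetical wiring, assuming the SafeTestsets-based layout common in SciML repositories:

    using SafeTestsets
    @safetestset "ChainRules extension" begin
        include("chainrules.jl")
    end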