diff --git a/Project.toml b/Project.toml
index 48765eb2..fd26fbb7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -4,23 +4,26 @@ authors = ["Vedant Puri "]
 version = "1.6.0"
 
 [deps]
+Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
-Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 
 [weakdeps]
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
 
 [extensions]
+SciMLOperatorsChainRulesCoreExt = "ChainRulesCore"
 SciMLOperatorsSparseArraysExt = "SparseArrays"
 SciMLOperatorsStaticArraysCoreExt = "StaticArraysCore"
 
 [compat]
 Accessors = "0.1.42"
 ArrayInterface = "7.19"
+ChainRulesCore = "1.26.0"
 DocStringExtensions = "0.9.4"
 LinearAlgebra = "1.10"
 MacroTools = "0.5.16"
diff --git a/ext/SciMLOperatorsChainRulesCoreExt.jl b/ext/SciMLOperatorsChainRulesCoreExt.jl
new file mode 100644
index 00000000..ec292983
--- /dev/null
+++ b/ext/SciMLOperatorsChainRulesCoreExt.jl
@@ -0,0 +1,50 @@
+module SciMLOperatorsChainRulesCoreExt
+
+using SciMLOperators
+using ChainRulesCore
+import SciMLOperators: ScaledOperator, ScalarOperator, AbstractSciMLOperator
+
+"""
+    rrule(::Type{ScaledOperator}, λ::ScalarOperator, L::AbstractSciMLOperator)
+
+Custom reverse rule for the `ScaledOperator` constructor.
+
+Without this rule Zygote double-counts the gradient of a parameter-dependent
+`ScalarOperator`: once through the scalar's value and once through the scalar
+being stored as a struct field of the `ScaledOperator`. The pullback below
+propagates only the `val` component of the `ScalarOperator` tangent so each
+contribution is counted exactly once.
+
+Fixes https://github.com/SciML/SciMLOperators.jl/issues/305
+"""
+function ChainRulesCore.rrule(
+        ::Type{ScaledOperator}, λ::ScalarOperator, L::AbstractSciMLOperator)
+    # Forward pass is identical to the plain constructor.
+    result = ScaledOperator(λ, L)
+
+    function ScaledOperator_pullback(Ȳ)
+        # Zero cotangent: nothing flows back to either argument.
+        Ȳ isa AbstractZero && return (NoTangent(), NoTangent(), NoTangent())
+
+        # `Ȳ` is a structural tangent (`Tangent` or a Zygote `NamedTuple`)
+        # keyed by `ScaledOperator`'s field names. NOTE: a `Tangent`'s only
+        # real struct field is its internal `backing`, so components MUST be
+        # read via `getproperty` — `hasfield`/`getfield` never see `:λ`/`:L`.
+        ∂λ_raw = hasproperty(Ȳ, :λ) ? Ȳ.λ : NoTangent()
+        ∂L = hasproperty(Ȳ, :L) ? Ȳ.L : NoTangent()
+
+        # Keep only the tangent of `val`; dropping the tangent of the stored
+        # update function avoids the structural double-counting.
+        ∂λ = if ∂λ_raw isa Union{Tangent, NamedTuple} && hasproperty(∂λ_raw, :val)
+            Tangent{typeof(λ)}(val = ∂λ_raw.val)
+        else
+            ∂λ_raw
+        end
+
+        return (NoTangent(), ∂λ, ∂L)
+    end
+
+    return result, ScaledOperator_pullback
+end
+
+end # module
diff --git a/test/chainrules.jl b/test/chainrules.jl
new file mode 100644
index 00000000..153afe64
--- /dev/null
+++ b/test/chainrules.jl
@@ -0,0 +1,119 @@
+# Tests for ChainRules extension fixing gradient double-counting issue
+# These tests specifically target issue #305
+
+using SciMLOperators
+using LinearSolve, Zygote, Test
+using SciMLOperators: ScaledOperator
+
+@testset "ChainRules fix for ScalarOperator gradient double-counting" begin
+    # Test 1: Simple ScaledOperator creation
+    @testset "Simple ScaledOperator gradient" begin
+        simple_func = p -> 2.0 * p
+
+        # Create ScalarOperator and matrix operator
+        S = ScalarOperator(0.0, (A, u, p, t) -> simple_func(p))
+        M = MatrixOperator(ones(2, 2))
+
+        # Test that ScaledOperator creation doesn't double-count gradients
+        function test_scaled(p)
+            S_val = ScalarOperator(simple_func(p))
+            scaled = S_val * M
+            return scaled.λ.val
+        end
+
+        p_val = 0.5
+        result = test_scaled(p_val)
+        grad = Zygote.gradient(test_scaled, p_val)[1]
+
+        @test result ≈ simple_func(p_val)
+        @test grad ≈ 2.0  # Should not be doubled (4.0)
+    end
+
+    # Test 2: Full update_coefficients pipeline
+    @testset "update_coefficients pipeline" begin
+        exp_func = p -> exp(1 - p)
+
+        A1 = MatrixOperator(rand(3, 3))
+        A2 = MatrixOperator(rand(3, 3))
+        Func = ScalarOperator(0.0, (A, u, p, t) -> exp_func(p))
+        A = A1 + Func * A2
+
+        # Test that update_coefficients doesn't cause gradient doubling
+        function test_update(p)
+            A_updated = update_coefficients(A, 0, p, 0)
+            # Access the scalar value from the updated composition
+            scaled_op = A_updated.ops[2]  # This should be the ScaledOperator
+            return scaled_op.λ.val
+        end
+
+        p_val = 0.3
+        result = test_update(p_val)
+        grad = Zygote.gradient(test_update, p_val)[1]
+
+        @test result ≈ exp_func(p_val)
+        # Gradient should match d/dp exp(1 - p) = -exp(1 - p)
+        expected_grad = -exp(1 - p_val)
+        @test grad ≈ expected_grad
+    end
+
+    # Test 3: Original MWE from issue #305
+    @testset "Original MWE from issue #305" begin
+        a1 = rand(3, 3)
+        a2 = rand(3, 3)
+        func = p -> exp(1 - p)
+        a = p -> a1 + func(p) * a2
+
+        A1 = MatrixOperator(a1)
+        A2 = MatrixOperator(a2)
+        Func = ScalarOperator(0.0, (A, u, p, t) -> func(p))
+        A = A1 + Func * A2
+
+        b = rand(3)
+
+        function sol1(p)
+            Ap = update_coefficients(A, 0, p, 0) |> concretize
+            prob = LinearProblem(Ap, b)
+            sol = solve(prob, KrylovJL_GMRES())
+            return sum(sol.u)
+        end
+
+        function sol2(p)
+            Ap = a(p)
+            prob = LinearProblem(Ap, b)
+            sol = solve(prob, KrylovJL_GMRES())
+            return sum(sol.u)
+        end
+
+        # Fixed parameter value keeps this test deterministic run-to-run.
+        p_val = 0.37
+        s1, s2 = sol1(p_val), sol2(p_val)
+
+        # Primal solutions should match
+        @test s1 ≈ s2
+
+        grad1 = Zygote.gradient(sol1, p_val)[1]
+        grad2 = Zygote.gradient(sol2, p_val)[1]
+
+        # Gradients should agree (no more doubling). Gradients through a
+        # Krylov solve are only accurate to solver tolerance, so don't
+        # demand 1e-10 agreement here.
+        @test grad1 ≈ grad2 rtol=1e-6
+        @test !(grad1 ≈ 2 * grad2)  # Should NOT be doubled anymore
+    end
+
+    # Test 4: Direct ScaledOperator constructor (the specific case our rrule fixes)
+    @testset "Direct ScaledOperator constructor" begin
+        func = p -> 3.0 * p
+
+        function test_direct_constructor(p)
+            S = ScalarOperator(func(p))
+            M = MatrixOperator([2.0 1.0; 1.0 2.0])
+            scaled = ScaledOperator(S, M)  # This should use our rrule
+            return scaled.λ.val
+        end
+
+        p_val = 0.5
+        result = test_direct_constructor(p_val)
+        grad = Zygote.gradient(test_direct_constructor, p_val)[1]
+
+        @test result ≈ func(p_val)
+        @test grad ≈ 3.0  # Should not be doubled (6.0)
+    end
+end