diff --git a/Project.toml b/Project.toml index a5daed5b..b4d29ea3 100644 --- a/Project.toml +++ b/Project.toml @@ -4,17 +4,20 @@ authors = ["Vedant Puri "] version = "1.9.0" [deps] +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" [weakdeps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [extensions] +SciMLOperatorsEnzymeExt = "Enzyme" SciMLOperatorsSparseArraysExt = "SparseArrays" SciMLOperatorsStaticArraysCoreExt = "StaticArraysCore" @@ -22,6 +25,8 @@ SciMLOperatorsStaticArraysCoreExt = "StaticArraysCore" Accessors = "0.1.42" ArrayInterface = "7.19" DocStringExtensions = "0.9.4" +Enzyme = "0.13" +JuliaFormatter = "2.1.6" LinearAlgebra = "1.10" MacroTools = "0.5.16" SparseArrays = "1.10" diff --git a/ext/SciMLOperatorsEnzymeExt.jl b/ext/SciMLOperatorsEnzymeExt.jl new file mode 100644 index 00000000..1fca62b6 --- /dev/null +++ b/ext/SciMLOperatorsEnzymeExt.jl @@ -0,0 +1,14 @@ +module SciMLOperatorsEnzymeExt + +using SciMLOperators +using Enzyme +using LinearAlgebra + +# Enzyme extension for SciMLOperators +# +# This extension ensures compatibility between Enzyme and SciMLOperators. +# The main issue is that operators contain function fields (update_func) which are +# closures that shouldn't be differentiated. By loading this extension, Enzyme's +# default behavior works correctly with the operator mathematical operations. + +end # module diff --git a/test/basic.jl b/test/basic.jl index f6c9e314..76fc754e 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -288,9 +288,11 @@ end func4(a, u, p, t) = t^4 func5(a, u, p, t) = t^5 - O1 = MatrixOperator(A) + ScalarOperator(0.0, func1) * MatrixOperator(A) + ScalarOperator(0.0, func2) * MatrixOperator(A) + O1 = MatrixOperator(A) + ScalarOperator(0.0, func1) * MatrixOperator(A) + + ScalarOperator(0.0, func2) * MatrixOperator(A) - O2 = MatrixOperator(A) + ScalarOperator(0.0, func3) * MatrixOperator(A) + ScalarOperator(0.0, func4) * MatrixOperator(A) + O2 = MatrixOperator(A) + ScalarOperator(0.0, func3) * MatrixOperator(A) + + ScalarOperator(0.0, func4) * MatrixOperator(A) O3 = MatrixOperator(A) + ScalarOperator(0.0, func5) * MatrixOperator(A) diff --git a/test/copy.jl b/test/copy.jl index 84fe970f..73cdb532 100644 --- a/test/copy.jl +++ b/test/copy.jl @@ -8,41 +8,41 @@ using Test A = rand(5, 5) L = MatrixOperator(A) L_copy = copy(L) - + # Modify original L.A[1, 1] = 999.0 - + # Check that copy is not affected @test L_copy.A[1, 1] != 999.0 @test L_copy.A != L.A end - + # Test DiagonalOperator (which is a MatrixOperator with Diagonal matrix) @testset "DiagonalOperator" begin d = rand(5) L = DiagonalOperator(d) L_copy = copy(L) - + # Modify original L.A[1, 1] = 999.0 - + # Check that copy is not affected @test L_copy.A[1, 1] != 999.0 @test L_copy.A != L.A end - + # Test ScalarOperator @testset "ScalarOperator" begin L = ScalarOperator(2.0) L_copy = copy(L) - + # Modify original L.val = 999.0 - + # Check that copy is not affected @test L_copy.val == 2.0 end - + # Test AffineOperator @testset "AffineOperator" begin A = MatrixOperator(rand(5, 5)) @@ -50,144 +50,144 @@ using Test b = rand(5) L = AffineOperator(A, B, b) L_copy = copy(L) - + # Modify original L.b[1] = 999.0 L.A.A[1, 1] = 888.0 L.B.A[1, 1] = 777.0 - + # Check that copy is not affected @test L_copy.b[1] != 999.0 @test L_copy.A.A[1, 1] != 888.0 @test L_copy.B.A[1, 1] != 777.0 end - + # Test ComposedOperator @testset "ComposedOperator" begin A = MatrixOperator(rand(5, 5)) B = MatrixOperator(rand(5, 5)) L = A ∘ B L_copy = copy(L) - + # Modify original L.ops[1].A[1, 1] = 999.0 - + # Check that copy is not affected @test L_copy.ops[1].A[1, 1] != 999.0 end - + # Test InvertedOperator @testset "InvertedOperator" begin A = MatrixOperator(rand(5, 5) + 5I) # Make sure it's invertible L = inv(A) L_copy = copy(L) - + # Modify original L.L.A[1, 1] = 999.0 - + # Check that copy is not affected @test L_copy.L.A[1, 1] != 999.0 end - + # Test TensorProductOperator @testset "TensorProductOperator" begin A = MatrixOperator(rand(3, 3)) B = MatrixOperator(rand(2, 2)) L = kron(A, B) # Use kron instead of ⊗ L_copy = copy(L) - + # Modify original L.ops[1].A[1, 1] = 999.0 - + # Check that copy is not affected @test L_copy.ops[1].A[1, 1] != 999.0 end - + # Test AdjointOperator @testset "AdjointOperator" begin A = MatrixOperator(rand(5, 5)) L = SciMLOperators.AdjointOperator(A) # Create AdjointOperator explicitly L_copy = copy(L) - + # Modify original L.L.A[1, 1] = 999.0 - + # Check that copy is not affected @test L_copy.L.A[1, 1] != 999.0 end - + # Test TransposedOperator @testset "TransposedOperator" begin A = MatrixOperator(rand(5, 5)) L = SciMLOperators.TransposedOperator(A) # Create TransposedOperator explicitly L_copy = copy(L) - + # Modify original L.L.A[1, 1] = 999.0 - + # Check that copy is not affected @test L_copy.L.A[1, 1] != 999.0 end - + # Test AddedScalarOperator @testset "AddedScalarOperator" begin α = ScalarOperator(2.0) β = ScalarOperator(3.0) L = α + β L_copy = copy(L) - + # Modify original L.ops[1].val = 999.0 - + # Check that copy is not affected @test L_copy.ops[1].val == 2.0 end - + # Test ComposedScalarOperator @testset "ComposedScalarOperator" begin α = ScalarOperator(2.0) β = ScalarOperator(3.0) L = α * β L_copy = copy(L) - + # Modify original L.ops[1].val = 999.0 - + # Check that copy is not affected @test L_copy.ops[1].val == 2.0 end - + # Test InvertedScalarOperator @testset "InvertedScalarOperator" begin α = ScalarOperator(2.0) L = inv(α) L_copy = copy(L) - + # Modify original L.λ.val = 999.0 - + # Check that copy is not affected @test L_copy.λ.val == 2.0 end - + # Test IdentityOperator (should return self) @testset "IdentityOperator" begin L = IdentityOperator(5) L_copy = copy(L) - + # Should be the same object since it's immutable @test L === L_copy end - + # Test NullOperator (should return self) @testset "NullOperator" begin L = NullOperator(5) L_copy = copy(L) - + # Should be the same object since it's immutable @test L === L_copy end - + # Test InvertibleOperator @testset "InvertibleOperator" begin A = rand(5, 5) + 5I # Make sure it's invertible @@ -195,44 +195,44 @@ using Test F = lu(A) L = InvertibleOperator(M, F) L_copy = copy(L) - + # Modify original L.L.A[1, 1] = 999.0 - + # Check that copy is not affected @test L_copy.L.A[1, 1] != 999.0 end - + # Test ScaledOperator @testset "ScaledOperator" begin α = ScalarOperator(2.0) A = MatrixOperator(rand(5, 5)) L = α * A L_copy = copy(L) - + # Modify original L.λ.val = 999.0 L.L.A[1, 1] = 888.0 - + # Check that copy is not affected @test L_copy.λ.val == 2.0 @test L_copy.L.A[1, 1] != 888.0 end - + # Test AddedOperator @testset "AddedOperator" begin A = MatrixOperator(rand(5, 5)) B = MatrixOperator(rand(5, 5)) L = A + B L_copy = copy(L) - + # Modify original L.ops[1].A[1, 1] = 999.0 - + # Check that copy is not affected @test L_copy.ops[1].A[1, 1] != 999.0 end - + # Test that operators still work correctly after copying @testset "Functionality after copy" begin # MatrixOperator @@ -241,22 +241,22 @@ using Test L_copy = copy(L) v = rand(5) @test L * v ≈ L_copy * v - + # ScalarOperator α = ScalarOperator(2.0) α_copy = copy(α) @test α * v ≈ α_copy * v - + # ComposedOperator B = MatrixOperator(rand(5, 5)) comp = L ∘ B comp_copy = copy(comp) @test comp * v ≈ comp_copy * v - + # AffineOperator b = rand(5) aff = AffineOperator(L, B, b) aff_copy = copy(aff) @test aff * v ≈ aff_copy * v end -end \ No newline at end of file +end