Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,29 @@ authors = ["Vedant Puri <[email protected]>"]
version = "1.9.0"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[extensions]
SciMLOperatorsEnzymeExt = "Enzyme"
SciMLOperatorsSparseArraysExt = "SparseArrays"
SciMLOperatorsStaticArraysCoreExt = "StaticArraysCore"

[compat]
Accessors = "0.1.42"
ArrayInterface = "7.19"
DocStringExtensions = "0.9.4"
Enzyme = "0.13"
JuliaFormatter = "2.1.6"
LinearAlgebra = "1.10"
MacroTools = "0.5.16"
SparseArrays = "1.10"
Expand Down
14 changes: 14 additions & 0 deletions ext/SciMLOperatorsEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module SciMLOperatorsEnzymeExt

using SciMLOperators
using Enzyme
using LinearAlgebra

# Enzyme extension for SciMLOperators
#
# This extension ensures compatibility between Enzyme and SciMLOperators.
# The main issue is that operators contain function fields (update_func) which are
# closures that shouldn't be differentiated. By loading this extension, Enzyme's
# default behavior works correctly with the operator mathematical operations.

end # module
6 changes: 4 additions & 2 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,11 @@ end
func4(a, u, p, t) = t^4
func5(a, u, p, t) = t^5

O1 = MatrixOperator(A) + ScalarOperator(0.0, func1) * MatrixOperator(A) + ScalarOperator(0.0, func2) * MatrixOperator(A)
O1 = MatrixOperator(A) + ScalarOperator(0.0, func1) * MatrixOperator(A) +
ScalarOperator(0.0, func2) * MatrixOperator(A)

O2 = MatrixOperator(A) + ScalarOperator(0.0, func3) * MatrixOperator(A) + ScalarOperator(0.0, func4) * MatrixOperator(A)
O2 = MatrixOperator(A) + ScalarOperator(0.0, func3) * MatrixOperator(A) +
ScalarOperator(0.0, func4) * MatrixOperator(A)

O3 = MatrixOperator(A) + ScalarOperator(0.0, func5) * MatrixOperator(A)

Expand Down
106 changes: 53 additions & 53 deletions test/copy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,231 +8,231 @@ using Test
A = rand(5, 5)
L = MatrixOperator(A)
L_copy = copy(L)

# Modify original
L.A[1, 1] = 999.0

# Check that copy is not affected
@test L_copy.A[1, 1] != 999.0
@test L_copy.A != L.A
end

# Test DiagonalOperator (which is a MatrixOperator with Diagonal matrix)
@testset "DiagonalOperator" begin
d = rand(5)
L = DiagonalOperator(d)
L_copy = copy(L)

# Modify original
L.A[1, 1] = 999.0

# Check that copy is not affected
@test L_copy.A[1, 1] != 999.0
@test L_copy.A != L.A
end

# Test ScalarOperator
@testset "ScalarOperator" begin
L = ScalarOperator(2.0)
L_copy = copy(L)

# Modify original
L.val = 999.0

# Check that copy is not affected
@test L_copy.val == 2.0
end

# Test AffineOperator
@testset "AffineOperator" begin
A = MatrixOperator(rand(5, 5))
B = MatrixOperator(rand(5, 5))
b = rand(5)
L = AffineOperator(A, B, b)
L_copy = copy(L)

# Modify original
L.b[1] = 999.0
L.A.A[1, 1] = 888.0
L.B.A[1, 1] = 777.0

# Check that copy is not affected
@test L_copy.b[1] != 999.0
@test L_copy.A.A[1, 1] != 888.0
@test L_copy.B.A[1, 1] != 777.0
end

# Test ComposedOperator
@testset "ComposedOperator" begin
A = MatrixOperator(rand(5, 5))
B = MatrixOperator(rand(5, 5))
L = A ∘ B
L_copy = copy(L)

# Modify original
L.ops[1].A[1, 1] = 999.0

# Check that copy is not affected
@test L_copy.ops[1].A[1, 1] != 999.0
end

# Test InvertedOperator
@testset "InvertedOperator" begin
A = MatrixOperator(rand(5, 5) + 5I) # Make sure it's invertible
L = inv(A)
L_copy = copy(L)

# Modify original
L.L.A[1, 1] = 999.0

# Check that copy is not affected
@test L_copy.L.A[1, 1] != 999.0
end

# Test TensorProductOperator
@testset "TensorProductOperator" begin
A = MatrixOperator(rand(3, 3))
B = MatrixOperator(rand(2, 2))
L = kron(A, B) # Use kron instead of ⊗
L_copy = copy(L)

# Modify original
L.ops[1].A[1, 1] = 999.0

# Check that copy is not affected
@test L_copy.ops[1].A[1, 1] != 999.0
end

# Test AdjointOperator
@testset "AdjointOperator" begin
A = MatrixOperator(rand(5, 5))
L = SciMLOperators.AdjointOperator(A) # Create AdjointOperator explicitly
L_copy = copy(L)

# Modify original
L.L.A[1, 1] = 999.0

# Check that copy is not affected
@test L_copy.L.A[1, 1] != 999.0
end

# Test TransposedOperator
@testset "TransposedOperator" begin
A = MatrixOperator(rand(5, 5))
L = SciMLOperators.TransposedOperator(A) # Create TransposedOperator explicitly
L_copy = copy(L)

# Modify original
L.L.A[1, 1] = 999.0

# Check that copy is not affected
@test L_copy.L.A[1, 1] != 999.0
end

# Test AddedScalarOperator
@testset "AddedScalarOperator" begin
α = ScalarOperator(2.0)
β = ScalarOperator(3.0)
L = α + β
L_copy = copy(L)

# Modify original
L.ops[1].val = 999.0

# Check that copy is not affected
@test L_copy.ops[1].val == 2.0
end

# Test ComposedScalarOperator
@testset "ComposedScalarOperator" begin
α = ScalarOperator(2.0)
β = ScalarOperator(3.0)
L = α * β
L_copy = copy(L)

# Modify original
L.ops[1].val = 999.0

# Check that copy is not affected
@test L_copy.ops[1].val == 2.0
end

# Test InvertedScalarOperator
@testset "InvertedScalarOperator" begin
α = ScalarOperator(2.0)
L = inv(α)
L_copy = copy(L)

# Modify original
L.λ.val = 999.0

# Check that copy is not affected
@test L_copy.λ.val == 2.0
end

# Test IdentityOperator (should return self)
@testset "IdentityOperator" begin
L = IdentityOperator(5)
L_copy = copy(L)

# Should be the same object since it's immutable
@test L === L_copy
end

# Test NullOperator (should return self)
@testset "NullOperator" begin
L = NullOperator(5)
L_copy = copy(L)

# Should be the same object since it's immutable
@test L === L_copy
end

# Test InvertibleOperator
@testset "InvertibleOperator" begin
A = rand(5, 5) + 5I # Make sure it's invertible
M = MatrixOperator(A)
F = lu(A)
L = InvertibleOperator(M, F)
L_copy = copy(L)

# Modify original
L.L.A[1, 1] = 999.0

# Check that copy is not affected
@test L_copy.L.A[1, 1] != 999.0
end

# Test ScaledOperator
@testset "ScaledOperator" begin
α = ScalarOperator(2.0)
A = MatrixOperator(rand(5, 5))
L = α * A
L_copy = copy(L)

# Modify original
L.λ.val = 999.0
L.L.A[1, 1] = 888.0

# Check that copy is not affected
@test L_copy.λ.val == 2.0
@test L_copy.L.A[1, 1] != 888.0
end

# Test AddedOperator
@testset "AddedOperator" begin
A = MatrixOperator(rand(5, 5))
B = MatrixOperator(rand(5, 5))
L = A + B
L_copy = copy(L)

# Modify original
L.ops[1].A[1, 1] = 999.0

# Check that copy is not affected
@test L_copy.ops[1].A[1, 1] != 999.0
end

# Test that operators still work correctly after copying
@testset "Functionality after copy" begin
# MatrixOperator
Expand All @@ -241,22 +241,22 @@ using Test
L_copy = copy(L)
v = rand(5)
@test L * v ≈ L_copy * v

# ScalarOperator
α = ScalarOperator(2.0)
α_copy = copy(α)
@test α * v ≈ α_copy * v

# ComposedOperator
B = MatrixOperator(rand(5, 5))
comp = L ∘ B
comp_copy = copy(comp)
@test comp * v ≈ comp_copy * v

# AffineOperator
b = rand(5)
aff = AffineOperator(L, B, b)
aff_copy = copy(aff)
@test aff * v ≈ aff_copy * v
end
end
end
Loading