Skip to content

Commit f4c9b5c

Browse files
Add Enzyme integration tests and formatting fixes
- Created test/enzyme.jl with comprehensive Enzyme autodiff tests - Tests verify gradients are computed correctly (not NaN) for operators with parameter-dependent coefficients (issue #319) - Added Enzyme to test dependencies in test/Project.toml - Included Enzyme tests in Core test group - Applied JuliaFormatter to extension file Tests cover: - ScalarOperator with parameter-dependent coefficients - MatrixOperator with update functions - Composed operators (combinations of ScalarOp and MatrixOp) Related to #319
1 parent 30506e1 commit f4c9b5c

File tree

4 files changed

+113
-1
lines changed

4 files changed

+113
-1
lines changed

ext/SciMLOperatorsEnzymeExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ end
2727
# The function fields (update_func) are just code that computes coefficients.
2828

2929
# Mark specific scalar and matrix operator types that have function fields as inactive
30-
Enzyme.EnzymeRules.inactive_type(::Type{<:SciMLOperators.AbstractSciMLScalarOperator}) = true
30+
function Enzyme.EnzymeRules.inactive_type(::Type{<:SciMLOperators.AbstractSciMLScalarOperator})
31+
true
32+
end
3133
Enzyme.EnzymeRules.inactive_type(::Type{<:SciMLOperators.AbstractSciMLOperator}) = true
3234

3335
# Note: The actual differentiation will happen through the mathematical operations

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
23
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
34
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
45
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

test/enzyme.jl

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Test Enzyme integration with SciMLOperators
2+
# Verifies that gradients can be computed through operators with parameter-dependent coefficients
3+
# Related to issue #319
4+
5+
using SciMLOperators, Enzyme, LinearAlgebra, SparseArrays, Test
6+
7+
const T = Float64
8+
9+
# Test basic operator autodiff with Enzyme
10+
@testset "Enzyme autodiff with ScalarOperator" begin
11+
# Create operators with parameter-dependent coefficients
12+
coef1(a, u, p, t) = -p[1]
13+
coef2(a, u, p, t) = p[2]
14+
15+
A1_data = sparse(T[0.0 1.0; 0.0 0.0])
16+
A2_data = sparse(T[0.0 0.0; 1.0 0.0])
17+
18+
c1 = ScalarOperator(one(T), coef1)
19+
c2 = ScalarOperator(one(T), coef2)
20+
21+
U = c1 * MatrixOperator(A1_data) + c2 * MatrixOperator(A2_data)
22+
23+
# Simple loss function that uses the operator
24+
function loss(p)
25+
x = T[3.0, 4.0]
26+
t = 0.0
27+
28+
# Update coefficients and apply operator
29+
U_updated = update_coefficients(U, x, p, t)
30+
y = U_updated * x
31+
32+
return sum(abs2, y)
33+
end
34+
35+
# Test that Enzyme can compute gradients
36+
p = T[1.0, 2.0]
37+
dp = Enzyme.make_zero(p)
38+
39+
result = Enzyme.autodiff(Enzyme.Reverse, loss, Active, Duplicated(p, dp))
40+
41+
# Gradient should not be NaN (the original bug)
42+
@test !any(isnan, dp)
43+
@test !any(isinf, dp)
44+
45+
# Gradient should be non-zero (operators depend on parameters)
46+
@test any(!iszero, dp)
47+
end
48+
49+
@testset "Enzyme autodiff with MatrixOperator" begin
50+
# Test with matrix operator that has update function
51+
update_func(A, u, p, t) = p[1] * A
52+
53+
A_data = T[1.0 2.0; 3.0 4.0]
54+
L = MatrixOperator(A_data; update_func = update_func)
55+
56+
function loss2(p)
57+
x = T[1.0, 1.0]
58+
t = 0.0
59+
60+
L_updated = update_coefficients(L, x, p, t)
61+
y = L_updated * x
62+
63+
return sum(abs2, y)
64+
end
65+
66+
p = T[2.0]
67+
dp = Enzyme.make_zero(p)
68+
69+
result = Enzyme.autodiff(Enzyme.Reverse, loss2, Active, Duplicated(p, dp))
70+
71+
# Gradient should be valid
72+
@test !any(isnan, dp)
73+
@test !any(isinf, dp)
74+
@test any(!iszero, dp)
75+
end
76+
77+
@testset "Enzyme autodiff with composed operators" begin
78+
# Test more complex operator composition
79+
coef(a, u, p, t) = p[1]
80+
81+
A = MatrixOperator(T[1.0 0.0; 0.0 1.0])
82+
B = MatrixOperator(T[2.0 1.0; 1.0 2.0])
83+
α = ScalarOperator(one(T), coef)
84+
85+
# Composed operator: α * A + B
86+
C = α * A + B
87+
88+
function loss3(p)
89+
x = T[1.0, 2.0]
90+
t = 0.0
91+
92+
C_updated = update_coefficients(C, x, p, t)
93+
y = C_updated * x
94+
95+
return sum(y)
96+
end
97+
98+
p = T[3.0]
99+
dp = Enzyme.make_zero(p)
100+
101+
result = Enzyme.autodiff(Enzyme.Reverse, loss3, Active, Duplicated(p, dp))
102+
103+
# Gradient should be valid
104+
@test !any(isnan, dp)
105+
@test !any(isinf, dp)
106+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ end
2828
@time @safetestset "Zygote.jl" begin
2929
include("zygote.jl")
3030
end
31+
@time @safetestset "Enzyme.jl" begin
32+
include("enzyme.jl")
33+
end
3134
@time @safetestset "Copy methods" begin
3235
include("copy.jl")
3336
end

0 commit comments

Comments
 (0)