Skip to content

Commit 6432716

Browse files
author
Avik Pal
committed
More tests and some safety
1 parent 7c1f1b2 commit 6432716

File tree

4 files changed

+30
-13
lines changed

4 files changed

+30
-13
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ StaticArrays = "1.5"
110110
StaticArraysCore = "1.4.2"
111111
Test = "1"
112112
UnPack = "1"
113+
Zygote = "0.6.69"
113114
julia = "1.10"
114115

115116
[extras]
@@ -137,6 +138,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
137138
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
138139
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
139140
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
141+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
140142

141143
[targets]
142-
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs"]
144+
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote"]

ext/LinearSolveHYPREExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
9090
cache = LinearCache{
9191
typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc,
9292
typeof(Pl), typeof(Pr), typeof(reltol),
93-
typeof(__issquare(assumptions), typeof(sensealg))
93+
typeof(__issquare(assumptions)), typeof(sensealg)
9494
}(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
9595
maxiters, verbose, assumptions, sensealg)
9696
return cache

src/adjoint.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,7 @@ function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
4747
A_ = alias_A ? deepcopy(A) : A
4848
end
4949
else
50-
if alg isa DefaultLinearSolver
51-
A_ = deepcopy(A)
52-
else
53-
A_ = alias_A ? deepcopy(A) : A
54-
end
50+
A_ = deepcopy(A)
5551
end
5652

5753
sol = solve!(cache)

test/adjoint.jl

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,34 @@ end
5151

5252
dA, db1, db2 = Zygote.gradient(f3, A, b1, b1)
5353

54-
#= Needs ForwardDiff rules
55-
dA2 = ForwardDiff.gradient(x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
56-
db12 = ForwardDiff.gradient(x -> f3(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
57-
db22 = ForwardDiff.gradient(x -> f3(eltype(x).(A), eltype(x).(b1), x), copy(b1))
54+
dA2 = FiniteDiff.finite_difference_gradient(
55+
x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
56+
db12 = FiniteDiff.finite_difference_gradient(
57+
x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
58+
db22 = FiniteDiff.finite_difference_gradient(
59+
x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b1))
60+
61+
@test dAdA2 atol=5e-5
62+
@test db1 db12
63+
@test db2 db22
64+
65+
function f4(A, b1, b2; alg = LUFactorization())
66+
prob = LinearProblem(A, b1)
67+
sol1 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_LSMR()))
68+
prob = LinearProblem(A, b2)
69+
sol2 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_GMRES()))
70+
norm(sol1.u .+ sol2.u)
71+
end
72+
73+
dA, db1, db2 = Zygote.gradient(f4, A, b1, b1)
74+
75+
dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
76+
db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
77+
db22 = ForwardDiff.gradient(x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b1))
5878

59-
@test dAdA2 atol=5e-5
79+
@test dAdA2 atol=5e-5
6080
@test db1 db12
6181
@test db2 db22
62-
=#
6382

6483
A = rand(n, n);
6584
b1 = rand(n);

0 commit comments

Comments
 (0)