Skip to content

Commit f065a08

Browse files
authored
fix: only factorize for direct linear solve (#181)
1 parent 8c30ba7 commit f065a08

File tree

3 files changed

+32
-22
lines changed

3 files changed

+32
-22
lines changed

src/execution.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ end
4646
function build_A_aux(
4747
::MatrixRepresentation, implicit, prep, x, y, z, c, args...; suggested_backend
4848
)
49-
(; conditions, backends) = implicit
49+
(; conditions, linear_solver, backends) = implicit
5050
(; prep_A) = prep
5151
actual_backend = isnothing(backends) ? suggested_backend : backends.y
5252
contexts = (Constant(x), Constant(z), map(Constant, args)...)
@@ -56,7 +56,11 @@ function build_A_aux(
5656
else
5757
A = jacobian(f, prep_A, actual_backend, y, contexts...)
5858
end
59-
return factorize(A)
59+
if linear_solver isa DirectLinearSolver
60+
return factorize(A)
61+
else
62+
return A
63+
end
6064
end
6165

6266
function build_A_aux(
@@ -100,7 +104,7 @@ end
100104
function build_Aᵀ_aux(
101105
::MatrixRepresentation, implicit, prep, x, y, z, c, args...; suggested_backend
102106
)
103-
(; conditions, backends) = implicit
107+
(; conditions, linear_solver, backends) = implicit
104108
(; prep_Aᵀ) = prep
105109
actual_backend = isnothing(backends) ? suggested_backend : backends.y
106110
contexts = (Constant(x), Constant(z), map(Constant, args)...)
@@ -110,7 +114,11 @@ function build_Aᵀ_aux(
110114
else
111115
Aᵀ = transpose(jacobian(f, prep_Aᵀ, actual_backend, y, contexts...))
112116
end
113-
return factorize(Aᵀ)
117+
if linear_solver isa DirectLinearSolver
118+
return factorize(Aᵀ)
119+
else
120+
return Aᵀ
121+
end
114122
end
115123

116124
function build_Aᵀ_aux(

test/preparation.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
using ADTypes
66
using ADTypes: ForwardOrReverseMode, ForwardMode, ReverseMode
77
using ForwardDiff: ForwardDiff
8+
using LinearAlgebra: Factorization, TransposeFactorization
89
using Zygote: Zygote
910
using Test
1011

12+
const GenericMatrix = Union{AbstractMatrix,Factorization,TransposeFactorization}
13+
1114
solver(x) = sqrt.(x), nothing
1215
conditions(x, y, z) = y .^ 2 .- x
1316

@@ -41,8 +44,8 @@
4144
@test prep.prep_Aᵀ === nothing
4245
@test prep.prep_B !== nothing
4346
@test prep.prep_Bᵀ === nothing
44-
@test build_A(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix
45-
@test build_Aᵀ(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix
47+
@test build_A(implicit, prep, x, y, z, c; suggested_backend) isa GenericMatrix
48+
@test build_Aᵀ(implicit, prep, x, y, z, c; suggested_backend) isa GenericMatrix
4649
@test build_B(implicit, prep, x, y, z, c; suggested_backend) isa JVP
4750
@test build_Bᵀ(implicit, prep, x, y, z, c; suggested_backend) isa VJP
4851
end
@@ -53,8 +56,8 @@
5356
@test prep.prep_Aᵀ !== nothing
5457
@test prep.prep_B === nothing
5558
@test prep.prep_Bᵀ !== nothing
56-
@test build_A(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix
57-
@test build_Aᵀ(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix
59+
@test build_A(implicit, prep, x, y, z, c; suggested_backend) isa GenericMatrix
60+
@test build_Aᵀ(implicit, prep, x, y, z, c; suggested_backend) isa GenericMatrix
5861
@test build_B(implicit, prep, x, y, z, c; suggested_backend) isa JVP
5962
@test build_Bᵀ(implicit, prep, x, y, z, c; suggested_backend) isa VJP
6063
end

test/systematic.jl

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,44 @@
11
using TestItems
22

3-
@testitem "Direct" setup = [TestUtils] begin
3+
@testitem "Matrix" setup = [TestUtils] begin
44
using ADTypes, .TestUtils
5-
for (backends, x) in
6-
Iterators.product([nothing, (; x=AutoForwardDiff(), y=AutoZygote())], [float.(1:3)])
5+
representation = MatrixRepresentation()
6+
for (linear_solver, backends, x) in Iterators.product(
7+
[DirectLinearSolver(), IterativeLinearSolver()],
8+
[nothing, (; x=AutoForwardDiff(), y=AutoZygote())],
9+
[float.(1:3)],
10+
)
711
yield()
812
scen = Scenario(;
913
solver=default_solver,
1014
conditions=default_conditions,
1115
x=x,
12-
implicit_kwargs=(;
13-
representation=MatrixRepresentation(),
14-
linear_solver=DirectLinearSolver(),
15-
backends,
16-
),
16+
implicit_kwargs=(; representation, linear_solver, backends),
1717
)
1818
scen2 = add_arg_mult(scen)
1919
test_implicit(scen)
2020
test_implicit(scen2)
2121
end
2222
end;
2323

24-
@testitem "Iterative" setup = [TestUtils] begin
24+
@testitem "Operator" setup = [TestUtils] begin
2525
using ADTypes, .TestUtils
26-
for (backends, linear_solver, x) in Iterators.product(
27-
[nothing, (; x=AutoForwardDiff(), y=AutoZygote())],
26+
representation = OperatorRepresentation()
27+
for (linear_solver, backends, x) in Iterators.product(
2828
[
2929
IterativeLinearSolver(),
3030
IterativeLinearSolver(; rtol=1e-8),
3131
IterativeLinearSolver(; issymmetric=true, isposdef=true),
3232
],
33+
[nothing, (; x=AutoForwardDiff(), y=AutoZygote())],
3334
[float.(1:3), reshape(float.(1:6), 3, 2)],
3435
)
3536
yield()
3637
scen = Scenario(;
3738
solver=default_solver,
3839
conditions=default_conditions,
3940
x=x,
40-
implicit_kwargs=(;
41-
representation=OperatorRepresentation(), linear_solver, backends
42-
),
41+
implicit_kwargs=(; representation, linear_solver, backends),
4342
)
4443
scen2 = add_arg_mult(scen)
4544
test_implicit(scen; type_stability=VERSION >= v"1.11")

0 commit comments

Comments
 (0)