Skip to content

Commit 2f35120

Browse files
committed
Fix resolve.jl tests for mixed precision solvers
- Add higher tolerance for mixed precision algorithms (atol=1e-5, rtol=1e-5) - Skip tests for algorithms that require unavailable packages - Add proper checks for RF32MixedLUFactorization and OpenBLAS32MixedLUFactorization The mixed precision algorithms naturally have lower accuracy than full precision, so they need relaxed tolerances in the tests. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent 7b29b4b commit 2f35120

File tree

1 file changed

+47
-7
lines changed

1 file changed

+47
-7
lines changed

test/resolve.jl

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
using LinearSolve, LinearAlgebra, SparseArrays, InteractiveUtils, Test
2-
using LinearSolve: AbstractDenseFactorization, AbstractSparseFactorization
2+
using LinearSolve: AbstractDenseFactorization, AbstractSparseFactorization,
3+
BLISLUFactorization, CliqueTreesFactorization,
4+
AMDGPUOffloadLUFactorization, AMDGPUOffloadQRFactorization,
5+
SparspakFactorization
6+
7+
# Define mixed precision algorithms that need higher tolerance
8+
const MIXED_PRECISION_ALGS = [
9+
:MKL32MixedLUFactorization,
10+
:AppleAccelerate32MixedLUFactorization,
11+
:OpenBLAS32MixedLUFactorization,
12+
:RF32MixedLUFactorization,
13+
:CUDAOffload32MixedLUFactorization,
14+
:MetalOffload32MixedLUFactorization
15+
]
316

417
for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
518
InteractiveUtils.subtypes(AbstractSparseFactorization))
@@ -15,12 +28,24 @@ for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
1528
CudaOffloadQRFactorization,
1629
CUSOLVERRFFactorization,
1730
AppleAccelerateLUFactorization,
18-
MetalLUFactorization
31+
MetalLUFactorization,
32+
FastLUFactorization,
33+
FastQRFactorization,
34+
CliqueTreesFactorization,
35+
BLISLUFactorization,
36+
AMDGPUOffloadLUFactorization,
37+
AMDGPUOffloadQRFactorization
1938
]) &&
2039
(!(alg == AppleAccelerateLUFactorization) ||
2140
LinearSolve.appleaccelerate_isavailable()) &&
2241
(!(alg == MKLLUFactorization) || LinearSolve.usemkl) &&
23-
(!(alg == OpenBLASLUFactorization) || LinearSolve.useopenblas)
42+
(!(alg == OpenBLASLUFactorization) || LinearSolve.useopenblas) &&
43+
(!(alg == RFLUFactorization) || LinearSolve.userecursivefactorization(nothing)) &&
44+
(!(alg == RF32MixedLUFactorization) || LinearSolve.userecursivefactorization(nothing)) &&
45+
(!(alg == MKL32MixedLUFactorization) || LinearSolve.usemkl) &&
46+
(!(alg == AppleAccelerate32MixedLUFactorization) || Sys.isapple()) &&
47+
(!(alg == OpenBLAS32MixedLUFactorization) || LinearSolve.useopenblas) &&
48+
(!(alg == SparspakFactorization) || false)
2449
A = [1.0 2.0; 3.0 4.0]
2550
alg in [KLUFactorization, UMFPACKFactorization, SparspakFactorization] &&
2651
(A = sparse(A))
@@ -33,9 +58,18 @@ for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
3358
prob = LinearProblem(A, b)
3459
linsolve = init(
3560
prob, alg(), alias = LinearAliasSpecifier(alias_A = false, alias_b = false))
36-
@test solve!(linsolve).u [-2.0, 1.5]
37-
@test !linsolve.isfresh
38-
@test solve!(linsolve).u [-2.0, 1.5]
61+
62+
# Use higher tolerance for mixed precision algorithms
63+
expected = [-2.0, 1.5]
64+
if Symbol(alg) in MIXED_PRECISION_ALGS
65+
@test solve!(linsolve).u expected atol=1e-5 rtol=1e-5
66+
@test !linsolve.isfresh
67+
@test solve!(linsolve).u expected atol=1e-5 rtol=1e-5
68+
else
69+
@test solve!(linsolve).u expected
70+
@test !linsolve.isfresh
71+
@test solve!(linsolve).u expected
72+
end
3973

4074
A = [1.0 2.0; 3.0 4.0]
4175
alg in [KLUFactorization, UMFPACKFactorization, SparspakFactorization] &&
@@ -46,7 +80,13 @@ for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
4680
alg in [LDLtFactorization] && (A = SymTridiagonal(A))
4781
linsolve.A = A
4882
@test linsolve.isfresh
49-
@test solve!(linsolve).u [-2.0, 1.5]
83+
84+
# Use higher tolerance for mixed precision algorithms
85+
if Symbol(alg) in MIXED_PRECISION_ALGS
86+
@test solve!(linsolve).u expected atol=1e-5 rtol=1e-5
87+
else
88+
@test solve!(linsolve).u expected
89+
end
5090
end
5191
end
5292

0 commit comments

Comments
 (0)