Skip to content

Commit 47c0df1

Browse files
committed
Fix mixed precision detection in resolve.jl tests
Use string matching to detect mixed precision algorithms instead of symbol comparison. This ensures the tolerance branch is properly taken for algorithms like RF32MixedLUFactorization.
1 parent 2ecb137 commit 47c0df1

File tree

1 file changed

+7
-11
lines changed

1 file changed

+7
-11
lines changed

test/resolve.jl

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,11 @@ using LinearSolve: AbstractDenseFactorization, AbstractSparseFactorization,
44
AMDGPUOffloadLUFactorization, AMDGPUOffloadQRFactorization,
55
SparspakFactorization
66

7-
# Define mixed precision algorithms that need higher tolerance
8-
const MIXED_PRECISION_ALGS = [
9-
:MKL32MixedLUFactorization,
10-
:AppleAccelerate32MixedLUFactorization,
11-
:OpenBLAS32MixedLUFactorization,
12-
:RF32MixedLUFactorization,
13-
:CUDAOffload32MixedLUFactorization,
14-
:MetalOffload32MixedLUFactorization
15-
]
7+
# Function to check if an algorithm is mixed precision
8+
function is_mixed_precision_alg(alg)
9+
alg_name = string(alg)
10+
return contains(alg_name, "32Mixed") || contains(alg_name, "Mixed32")
11+
end
1612

1713
for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
1814
InteractiveUtils.subtypes(AbstractSparseFactorization))
@@ -61,7 +57,7 @@ for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
6157

6258
# Use higher tolerance for mixed precision algorithms
6359
expected = [-2.0, 1.5]
64-
if Symbol(alg) in MIXED_PRECISION_ALGS
60+
if is_mixed_precision_alg(alg)
6561
@test solve!(linsolve).u expected atol=1e-4 rtol=1e-4
6662
@test !linsolve.isfresh
6763
@test solve!(linsolve).u expected atol=1e-4 rtol=1e-4
@@ -82,7 +78,7 @@ for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
8278
@test linsolve.isfresh
8379

8480
# Use higher tolerance for mixed precision algorithms
85-
if Symbol(alg) in MIXED_PRECISION_ALGS
81+
if is_mixed_precision_alg(alg)
8682
@test solve!(linsolve).u expected atol=1e-4 rtol=1e-4
8783
else
8884
@test solve!(linsolve).u expected

0 commit comments

Comments
 (0)