@@ -13,37 +13,38 @@ Uses more strict rules for BLAS-dependent algorithms with non-standard types.
1313function test_algorithm_compatibility (alg, eltype:: Type , test_size:: Int = 4 )
1414 # Get algorithm name for type-specific compatibility rules
1515 alg_name = string (typeof (alg). name. name)
16-
16+
1717 # Define strict compatibility rules for BLAS-dependent algorithms
18- if ! (eltype <: LinearAlgebra.BLAS.BlasFloat ) && alg_name in [" BLISFactorization" , " MKLLUFactorization" , " AppleAccelerateLUFactorization" ]
18+ if ! (eltype <: LinearAlgebra.BLAS.BlasFloat ) && alg_name in [
19+ " BLISFactorization" , " MKLLUFactorization" , " AppleAccelerateLUFactorization" ]
1920 return false # BLAS algorithms not compatible with non-standard types
2021 end
2122
2223 if alg_name == " BLISLUFactorization" && Sys. isapple ()
2324 return false # BLISLUFactorization has no Apple Silicon binary
2425 end
25-
26+
2627 # For standard types or algorithms that passed the strict check, test functionality
2728 try
2829 # Create a small test problem with the specified element type
2930 rng = MersenneTwister (123 )
3031 A = rand (rng, eltype, test_size, test_size)
3132 b = rand (rng, eltype, test_size)
3233 u0 = rand (rng, eltype, test_size)
33-
34+
3435 prob = LinearProblem (A, b; u0 = u0)
35-
36+
3637 # Try to solve - if it works, the algorithm is compatible
3738 sol = solve (prob, alg)
38-
39+
3940 # Additional check: verify the solution is actually of the expected type
4041 if ! isa (sol. u, AbstractVector{eltype})
4142 @debug " Algorithm $alg_name returned wrong element type for $eltype "
4243 return false
4344 end
44-
45+
4546 return true
46-
47+
4748 catch e
4849 # Algorithm failed - not compatible with this element type
4950 @debug " Algorithm $alg_name failed for $eltype : $e "
@@ -60,14 +61,14 @@ Returns filtered algorithms and names.
6061function filter_compatible_algorithms (algorithms, alg_names, eltype:: Type )
6162 compatible_algs = []
6263 compatible_names = String[]
63-
64+
6465 for (alg, name) in zip (algorithms, alg_names)
6566 if test_algorithm_compatibility (alg, eltype)
6667 push! (compatible_algs, alg)
6768 push! (compatible_names, name)
6869 end
6970 end
70-
71+
7172 return compatible_algs, compatible_names
7273end
7374
@@ -89,36 +90,37 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
8990
9091 # Initialize results DataFrame
9192 results_data = []
92-
93+
9394 # Calculate total number of benchmarks for progress bar
9495 total_benchmarks = 0
9596 for eltype in eltypes
9697 # Pre-filter to estimate the actual number
9798 test_algs, _ = filter_compatible_algorithms (algorithms, alg_names, eltype)
9899 total_benchmarks += length (matrix_sizes) * length (test_algs)
99100 end
100-
101+
101102 # Create progress bar
102- progress = Progress (total_benchmarks, desc= " Benchmarking: " ,
103- barlen= 50 , showspeed= true )
103+ progress = Progress (total_benchmarks, desc = " Benchmarking: " ,
104+ barlen = 50 , showspeed = true )
104105
105106 try
106107 for eltype in eltypes
107108 # Filter algorithms for this element type
108- compatible_algs, compatible_names = filter_compatible_algorithms (algorithms, alg_names, eltype)
109-
109+ compatible_algs,
110+ compatible_names = filter_compatible_algorithms (algorithms, alg_names, eltype)
111+
110112 if isempty (compatible_algs)
111113 @warn " No algorithms compatible with $eltype , skipping..."
112114 continue
113115 end
114-
116+
115117 for n in matrix_sizes
116118 # Create test problem with specified element type
117119 rng = MersenneTwister (123 ) # Consistent seed for reproducibility
118120 A = rand (rng, eltype, n, n)
119121 b = rand (rng, eltype, n)
120122 u0 = rand (rng, eltype, n)
121-
123+
122124 # Compute reference solution with LUFactorization if correctness check is enabled
123125 reference_solution = nothing
124126 if check_correctness
@@ -133,9 +135,9 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
133135
134136 for (alg, name) in zip (compatible_algs, compatible_names)
135137 # Update progress description
136- ProgressMeter. update! (progress,
137- desc= " Benchmarking $name on $(n) ×$(n) $eltype matrix: " )
138-
138+ ProgressMeter. update! (progress,
139+ desc = " Benchmarking $name on $(n) ×$(n) $eltype matrix: " )
140+
139141 gflops = 0.0
140142 success = true
141143 error_msg = " "
@@ -149,12 +151,13 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
149151
150152 # Warmup run and correctness check
151153 warmup_sol = solve (prob, alg)
152-
154+
153155 # Check correctness if reference solution is available
154156 if check_correctness && reference_solution != = nothing
155157 # Compute relative error
156- rel_error = norm (warmup_sol. u - reference_solution. u) / norm (reference_solution. u)
157-
158+ rel_error = norm (warmup_sol. u - reference_solution. u) /
159+ norm (reference_solution. u)
160+
158161 if rel_error > correctness_tol
159162 passed_correctness = false
160163 @warn " Algorithm $name failed correctness check for size $n , eltype $eltype . " *
@@ -164,7 +167,7 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
164167 error_msg = " Failed correctness check (rel_error = $(round (rel_error, sigdigits= 3 )) )"
165168 end
166169 end
167-
170+
168171 # Only benchmark if correctness check passed
169172 if passed_correctness
170173 # Actual benchmark
@@ -195,7 +198,7 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
195198 success = success,
196199 error = error_msg
197200 ))
198-
201+
199202 # Update progress
200203 ProgressMeter. next! (progress)
201204 end
@@ -216,15 +219,16 @@ end
216219Get the matrix sizes to benchmark based on the requested size categories.
217220
218221Size categories:
219- - `:tiny` - 5:5:20 (for very small problems)
220- - `:small` - 20:20:100 (for small problems)
221- - `:medium` - 100:50:300 (for typical problems)
222- - `:large` - 300:100:1000 (for larger problems)
223- - `:big` - vcat(1000:2000:10000, 10000:5000:20000) (for very large/GPU problems)
222+
223+ - `:tiny` - 5:5:20 (for very small problems)
224+ - `:small` - 20:20:100 (for small problems)
225+ - `:medium` - 100:50:300 (for typical problems)
226+ - `:large` - 300:100:1000 (for larger problems)
227+ - `:big` - vcat(1000:2000:10000, 10000:5000:20000) (for very large/GPU problems)
224228"""
225229function get_benchmark_sizes (size_categories:: Vector{Symbol} )
226230 sizes = Int[]
227-
231+
228232 for category in size_categories
229233 if category == :tiny
230234 append! (sizes, 5 : 5 : 20 )
@@ -240,7 +244,7 @@ function get_benchmark_sizes(size_categories::Vector{Symbol})
240244 @warn " Unknown size category: $category . Skipping."
241245 end
242246 end
243-
247+
244248 # Remove duplicates and sort
245249 return sort (unique (sizes))
246250end
@@ -277,10 +281,10 @@ function categorize_results(df::DataFrame)
277281
278282 for eltype in eltypes
279283 @info " Categorizing results for element type: $eltype "
280-
284+
281285 # Filter results for this element type
282286 eltype_df = filter (row -> row. eltype == eltype, successful_df)
283-
287+
284288 if nrow (eltype_df) == 0
285289 continue
286290 end
@@ -295,24 +299,27 @@ function categorize_results(df::DataFrame)
295299
296300 # Calculate average GFLOPs for each algorithm in this range
297301 avg_results = combine (groupby (range_df, :algorithm ), :gflops => mean => :avg_gflops )
298-
302+
299303 # Sort by performance
300- sort! (avg_results, :avg_gflops , rev= true )
304+ sort! (avg_results, :avg_gflops , rev = true )
301305
302306 # Find the best algorithm (for complex types, avoid RFLU if possible)
303307 if nrow (avg_results) > 0
304308 best_alg = avg_results. algorithm[1 ]
305-
309+
306310 # For complex types, check if best is RFLU and we have alternatives
307- if (eltype == " ComplexF32" || eltype == " ComplexF64" ) &&
308- (contains (best_alg, " RFLU" ) || contains (best_alg, " RecursiveFactorization" ))
309-
311+ if (eltype == " ComplexF32" || eltype == " ComplexF64" ) &&
312+ (contains (best_alg, " RFLU" ) ||
313+ contains (best_alg, " RecursiveFactorization" ))
314+
310315 # Look for the best non-RFLU algorithm
311316 for i in 2 : nrow (avg_results)
312317 alt_alg = avg_results. algorithm[i]
313- if ! contains (alt_alg, " RFLU" ) && ! contains (alt_alg, " RecursiveFactorization" )
318+ if ! contains (alt_alg, " RFLU" ) &&
319+ ! contains (alt_alg, " RecursiveFactorization" )
314320 # Check if performance difference is not too large (within 20%)
315- perf_ratio = avg_results. avg_gflops[i] / avg_results. avg_gflops[1 ]
321+ perf_ratio = avg_results. avg_gflops[i] /
322+ avg_results. avg_gflops[1 ]
316323 if perf_ratio > 0.8
317324 @info " Using $alt_alg instead of $best_alg for $eltype at $range_name ($(round (100 * perf_ratio, digits= 1 )) % of RFLU performance) to avoid complex number issues"
318325 best_alg = alt_alg
@@ -323,7 +330,7 @@ function categorize_results(df::DataFrame)
323330 end
324331 end
325332 end
326-
333+
327334 category_key = " $(eltype) _$(range_name) "
328335 categories[category_key] = best_alg
329336 best_idx = findfirst (== (best_alg), avg_results. algorithm)
0 commit comments