@@ -13,38 +13,37 @@ Uses more strict rules for BLAS-dependent algorithms with non-standard types.
1313function test_algorithm_compatibility (alg, eltype:: Type , test_size:: Int = 4 )
1414 # Get algorithm name for type-specific compatibility rules
1515 alg_name = string (typeof (alg). name. name)
16-
16+
1717 # Define strict compatibility rules for BLAS-dependent algorithms
18- if ! (eltype <: LinearAlgebra.BLAS.BlasFloat ) && alg_name in [
19- " BLISFactorization" , " MKLLUFactorization" , " AppleAccelerateLUFactorization" ]
18+ if ! (eltype <: LinearAlgebra.BLAS.BlasFloat ) && alg_name in [" BLISFactorization" , " MKLLUFactorization" , " AppleAccelerateLUFactorization" ]
2019 return false # BLAS algorithms not compatible with non-standard types
2120 end
2221
2322 if alg_name == " BLISLUFactorization" && Sys. isapple ()
2423 return false # BLISLUFactorization has no Apple Silicon binary
2524 end
26-
25+
2726 # For standard types or algorithms that passed the strict check, test functionality
2827 try
2928 # Create a small test problem with the specified element type
3029 rng = MersenneTwister (123 )
3130 A = rand (rng, eltype, test_size, test_size)
3231 b = rand (rng, eltype, test_size)
3332 u0 = rand (rng, eltype, test_size)
34-
33+
3534 prob = LinearProblem (A, b; u0 = u0)
36-
35+
3736 # Try to solve - if it works, the algorithm is compatible
3837 sol = solve (prob, alg)
39-
38+
4039 # Additional check: verify the solution is actually of the expected type
4140 if ! isa (sol. u, AbstractVector{eltype})
4241 @debug " Algorithm $alg_name returned wrong element type for $eltype "
4342 return false
4443 end
45-
44+
4645 return true
47-
46+
4847 catch e
4948 # Algorithm failed - not compatible with this element type
5049 @debug " Algorithm $alg_name failed for $eltype : $e "
@@ -61,14 +60,14 @@ Returns filtered algorithms and names.
function filter_compatible_algorithms(algorithms, alg_names, eltype::Type)
    # Accumulators for the (algorithm, name) pairs that pass the probe.
    # Algorithms are heterogeneous objects, so they are kept in a Vector{Any}.
    kept_algs = Any[]
    kept_names = String[]

    # Walk both inputs in lockstep; `zip` stops at the shorter of the two.
    for (candidate, label) in zip(algorithms, alg_names)
        # Guard clause: drop anything that fails the compatibility probe.
        test_algorithm_compatibility(candidate, eltype) || continue
        push!(kept_algs, candidate)
        push!(kept_names, label)
    end

    return kept_algs, kept_names
end
7473
@@ -90,37 +89,36 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
9089
9190 # Initialize results DataFrame
9291 results_data = []
93-
92+
9493 # Calculate total number of benchmarks for progress bar
9594 total_benchmarks = 0
9695 for eltype in eltypes
9796 # Pre-filter to estimate the actual number
9897 test_algs, _ = filter_compatible_algorithms (algorithms, alg_names, eltype)
9998 total_benchmarks += length (matrix_sizes) * length (test_algs)
10099 end
101-
100+
102101 # Create progress bar
103- progress = Progress (total_benchmarks, desc = " Benchmarking: " ,
104- barlen = 50 , showspeed = true )
102+ progress = Progress (total_benchmarks, desc= " Benchmarking: " ,
103+ barlen = 50 , showspeed= true )
105104
106105 try
107106 for eltype in eltypes
108107 # Filter algorithms for this element type
109- compatible_algs,
110- compatible_names = filter_compatible_algorithms (algorithms, alg_names, eltype)
111-
108+ compatible_algs, compatible_names = filter_compatible_algorithms (algorithms, alg_names, eltype)
109+
112110 if isempty (compatible_algs)
113111 @warn " No algorithms compatible with $eltype , skipping..."
114112 continue
115113 end
116-
114+
117115 for n in matrix_sizes
118116 # Create test problem with specified element type
119117 rng = MersenneTwister (123 ) # Consistent seed for reproducibility
120118 A = rand (rng, eltype, n, n)
121119 b = rand (rng, eltype, n)
122120 u0 = rand (rng, eltype, n)
123-
121+
124122 # Compute reference solution with LUFactorization if correctness check is enabled
125123 reference_solution = nothing
126124 if check_correctness
@@ -135,9 +133,9 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
135133
136134 for (alg, name) in zip (compatible_algs, compatible_names)
137135 # Update progress description
138- ProgressMeter. update! (progress,
139- desc = " Benchmarking $name on $(n) ×$(n) $eltype matrix: " )
140-
136+ ProgressMeter. update! (progress,
137+ desc= " Benchmarking $name on $(n) ×$(n) $eltype matrix: " )
138+
141139 gflops = 0.0
142140 success = true
143141 error_msg = " "
@@ -151,13 +149,12 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
151149
152150 # Warmup run and correctness check
153151 warmup_sol = solve (prob, alg)
154-
152+
155153 # Check correctness if reference solution is available
156154 if check_correctness && reference_solution != = nothing
157155 # Compute relative error
158- rel_error = norm (warmup_sol. u - reference_solution. u) /
159- norm (reference_solution. u)
160-
156+ rel_error = norm (warmup_sol. u - reference_solution. u) / norm (reference_solution. u)
157+
161158 if rel_error > correctness_tol
162159 passed_correctness = false
163160 @warn " Algorithm $name failed correctness check for size $n , eltype $eltype . " *
@@ -167,7 +164,7 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
167164 error_msg = " Failed correctness check (rel_error = $(round (rel_error, sigdigits= 3 )) )"
168165 end
169166 end
170-
167+
171168 # Only benchmark if correctness check passed
172169 if passed_correctness
173170 # Actual benchmark
@@ -198,7 +195,7 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
198195 success = success,
199196 error = error_msg
200197 ))
201-
198+
202199 # Update progress
203200 ProgressMeter. next! (progress)
204201 end
@@ -219,16 +216,15 @@ end
219216Get the matrix sizes to benchmark based on the requested size categories.
220217
221218Size categories:
222-
223- - `:tiny` - 5:5:20 (for very small problems)
224- - `:small` - 20:20:100 (for small problems)
225- - `:medium` - 100:50:300 (for typical problems)
226- - `:large` - 300:100:1000 (for larger problems)
227- - `:big` - vcat(1000:2000:10000, 10000:5000:20000) (for very large/GPU problems)
219+ - `:tiny` - 5:5:20 (for very small problems)
220+ - `:small` - 20:20:100 (for small problems)
221+ - `:medium` - 100:50:300 (for typical problems)
222+ - `:large` - 300:100:1000 (for larger problems)
223+ - `:big` - vcat(1000:2000:10000, 10000:5000:20000) (for very large/GPU problems)
228224"""
229225function get_benchmark_sizes (size_categories:: Vector{Symbol} )
230226 sizes = Int[]
231-
227+
232228 for category in size_categories
233229 if category == :tiny
234230 append! (sizes, 5 : 5 : 20 )
@@ -244,7 +240,7 @@ function get_benchmark_sizes(size_categories::Vector{Symbol})
244240 @warn " Unknown size category: $category . Skipping."
245241 end
246242 end
247-
243+
248244 # Remove duplicates and sort
249245 return sort (unique (sizes))
250246end
@@ -281,10 +277,10 @@ function categorize_results(df::DataFrame)
281277
282278 for eltype in eltypes
283279 @info " Categorizing results for element type: $eltype "
284-
280+
285281 # Filter results for this element type
286282 eltype_df = filter (row -> row. eltype == eltype, successful_df)
287-
283+
288284 if nrow (eltype_df) == 0
289285 continue
290286 end
@@ -299,27 +295,24 @@ function categorize_results(df::DataFrame)
299295
300296 # Calculate average GFLOPs for each algorithm in this range
301297 avg_results = combine (groupby (range_df, :algorithm ), :gflops => mean => :avg_gflops )
302-
298+
303299 # Sort by performance
304- sort! (avg_results, :avg_gflops , rev = true )
300+ sort! (avg_results, :avg_gflops , rev= true )
305301
306302 # Find the best algorithm (for complex types, avoid RFLU if possible)
307303 if nrow (avg_results) > 0
308304 best_alg = avg_results. algorithm[1 ]
309-
305+
310306 # For complex types, check if best is RFLU and we have alternatives
311- if (eltype == " ComplexF32" || eltype == " ComplexF64" ) &&
312- (contains (best_alg, " RFLU" ) ||
313- contains (best_alg, " RecursiveFactorization" ))
314-
307+ if (eltype == " ComplexF32" || eltype == " ComplexF64" ) &&
308+ (contains (best_alg, " RFLU" ) || contains (best_alg, " RecursiveFactorization" ))
309+
315310 # Look for the best non-RFLU algorithm
316311 for i in 2 : nrow (avg_results)
317312 alt_alg = avg_results. algorithm[i]
318- if ! contains (alt_alg, " RFLU" ) &&
319- ! contains (alt_alg, " RecursiveFactorization" )
313+ if ! contains (alt_alg, " RFLU" ) && ! contains (alt_alg, " RecursiveFactorization" )
320314 # Check if performance difference is not too large (within 20%)
321- perf_ratio = avg_results. avg_gflops[i] /
322- avg_results. avg_gflops[1 ]
315+ perf_ratio = avg_results. avg_gflops[i] / avg_results. avg_gflops[1 ]
323316 if perf_ratio > 0.8
324317 @info " Using $alt_alg instead of $best_alg for $eltype at $range_name ($(round (100 * perf_ratio, digits= 1 )) % of RFLU performance) to avoid complex number issues"
325318 best_alg = alt_alg
@@ -330,7 +323,7 @@ function categorize_results(df::DataFrame)
330323 end
331324 end
332325 end
333-
326+
334327 category_key = " $(eltype) _$(range_name) "
335328 categories[category_key] = best_alg
336329 best_idx = findfirst (== (best_alg), avg_results. algorithm)
0 commit comments