@@ -13,38 +13,37 @@ Uses more strict rules for BLAS-dependent algorithms with non-standard types.
13
13
function test_algorithm_compatibility (alg, eltype:: Type , test_size:: Int = 4 )
14
14
# Get algorithm name for type-specific compatibility rules
15
15
alg_name = string (typeof (alg). name. name)
16
-
16
+
17
17
# Define strict compatibility rules for BLAS-dependent algorithms
18
- if ! (eltype <: LinearAlgebra.BLAS.BlasFloat ) && alg_name in [
19
- " BLISFactorization" , " MKLLUFactorization" , " AppleAccelerateLUFactorization" ]
18
+ if ! (eltype <: LinearAlgebra.BLAS.BlasFloat ) && alg_name in [" BLISFactorization" , " MKLLUFactorization" , " AppleAccelerateLUFactorization" ]
20
19
return false # BLAS algorithms not compatible with non-standard types
21
20
end
22
21
23
22
if alg_name == " BLISLUFactorization" && Sys. isapple ()
24
23
return false # BLISLUFactorization has no Apple Silicon binary
25
24
end
26
-
25
+
27
26
# For standard types or algorithms that passed the strict check, test functionality
28
27
try
29
28
# Create a small test problem with the specified element type
30
29
rng = MersenneTwister (123 )
31
30
A = rand (rng, eltype, test_size, test_size)
32
31
b = rand (rng, eltype, test_size)
33
32
u0 = rand (rng, eltype, test_size)
34
-
33
+
35
34
prob = LinearProblem (A, b; u0 = u0)
36
-
35
+
37
36
# Try to solve - if it works, the algorithm is compatible
38
37
sol = solve (prob, alg)
39
-
38
+
40
39
# Additional check: verify the solution is actually of the expected type
41
40
if ! isa (sol. u, AbstractVector{eltype})
42
41
@debug " Algorithm $alg_name returned wrong element type for $eltype "
43
42
return false
44
43
end
45
-
44
+
46
45
return true
47
-
46
+
48
47
catch e
49
48
# Algorithm failed - not compatible with this element type
50
49
@debug " Algorithm $alg_name failed for $eltype : $e "
@@ -61,14 +60,14 @@ Returns filtered algorithms and names.
61
60
function filter_compatible_algorithms(algorithms, alg_names, eltype::Type)
    # Parallel accumulators: kept_algs[i] is the algorithm labelled kept_names[i].
    # The algorithm container is deliberately untyped (Any) because the candidate
    # algorithms are not required to share a common concrete type.
    kept_algs = Any[]
    kept_names = String[]

    # Walk both inputs in lockstep; pairing stops at the shorter of the two.
    for (candidate, label) in zip(algorithms, alg_names)
        # Skip any algorithm that the compatibility probe rejects for this eltype.
        test_algorithm_compatibility(candidate, eltype) || continue

        push!(kept_algs, candidate)
        push!(kept_names, label)
    end

    return kept_algs, kept_names
end
74
73
@@ -90,37 +89,36 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
90
89
91
90
# Initialize results DataFrame
92
91
results_data = []
93
-
92
+
94
93
# Calculate total number of benchmarks for progress bar
95
94
total_benchmarks = 0
96
95
for eltype in eltypes
97
96
# Pre-filter to estimate the actual number
98
97
test_algs, _ = filter_compatible_algorithms (algorithms, alg_names, eltype)
99
98
total_benchmarks += length (matrix_sizes) * length (test_algs)
100
99
end
101
-
100
+
102
101
# Create progress bar
103
- progress = Progress (total_benchmarks, desc = " Benchmarking: " ,
104
- barlen = 50 , showspeed = true )
102
+ progress = Progress (total_benchmarks, desc= " Benchmarking: " ,
103
+ barlen = 50 , showspeed= true )
105
104
106
105
try
107
106
for eltype in eltypes
108
107
# Filter algorithms for this element type
109
- compatible_algs,
110
- compatible_names = filter_compatible_algorithms (algorithms, alg_names, eltype)
111
-
108
+ compatible_algs, compatible_names = filter_compatible_algorithms (algorithms, alg_names, eltype)
109
+
112
110
if isempty (compatible_algs)
113
111
@warn " No algorithms compatible with $eltype , skipping..."
114
112
continue
115
113
end
116
-
114
+
117
115
for n in matrix_sizes
118
116
# Create test problem with specified element type
119
117
rng = MersenneTwister (123 ) # Consistent seed for reproducibility
120
118
A = rand (rng, eltype, n, n)
121
119
b = rand (rng, eltype, n)
122
120
u0 = rand (rng, eltype, n)
123
-
121
+
124
122
# Compute reference solution with LUFactorization if correctness check is enabled
125
123
reference_solution = nothing
126
124
if check_correctness
@@ -135,9 +133,9 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
135
133
136
134
for (alg, name) in zip (compatible_algs, compatible_names)
137
135
# Update progress description
138
- ProgressMeter. update! (progress,
139
- desc = " Benchmarking $name on $(n) ×$(n) $eltype matrix: " )
140
-
136
+ ProgressMeter. update! (progress,
137
+ desc= " Benchmarking $name on $(n) ×$(n) $eltype matrix: " )
138
+
141
139
gflops = 0.0
142
140
success = true
143
141
error_msg = " "
@@ -151,13 +149,12 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
151
149
152
150
# Warmup run and correctness check
153
151
warmup_sol = solve (prob, alg)
154
-
152
+
155
153
# Check correctness if reference solution is available
156
154
if check_correctness && reference_solution != = nothing
157
155
# Compute relative error
158
- rel_error = norm (warmup_sol. u - reference_solution. u) /
159
- norm (reference_solution. u)
160
-
156
+ rel_error = norm (warmup_sol. u - reference_solution. u) / norm (reference_solution. u)
157
+
161
158
if rel_error > correctness_tol
162
159
passed_correctness = false
163
160
@warn " Algorithm $name failed correctness check for size $n , eltype $eltype . " *
@@ -167,7 +164,7 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
167
164
error_msg = " Failed correctness check (rel_error = $(round (rel_error, sigdigits= 3 )) )"
168
165
end
169
166
end
170
-
167
+
171
168
# Only benchmark if correctness check passed
172
169
if passed_correctness
173
170
# Actual benchmark
@@ -198,7 +195,7 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
198
195
success = success,
199
196
error = error_msg
200
197
))
201
-
198
+
202
199
# Update progress
203
200
ProgressMeter. next! (progress)
204
201
end
@@ -219,16 +216,15 @@ end
219
216
Get the matrix sizes to benchmark based on the requested size categories.
220
217
221
218
Size categories:
222
-
223
- - `:tiny` - 5:5:20 (for very small problems)
224
- - `:small` - 20:20:100 (for small problems)
225
- - `:medium` - 100:50:300 (for typical problems)
226
- - `:large` - 300:100:1000 (for larger problems)
227
- - `:big` - vcat(1000:2000:10000, 10000:5000:20000) (for very large/GPU problems)
219
+ - `:tiny` - 5:5:20 (for very small problems)
220
+ - `:small` - 20:20:100 (for small problems)
221
+ - `:medium` - 100:50:300 (for typical problems)
222
+ - `:large` - 300:100:1000 (for larger problems)
223
+ - `:big` - vcat(1000:2000:10000, 10000:5000:20000) (for very large/GPU problems)
228
224
"""
229
225
function get_benchmark_sizes (size_categories:: Vector{Symbol} )
230
226
sizes = Int[]
231
-
227
+
232
228
for category in size_categories
233
229
if category == :tiny
234
230
append! (sizes, 5 : 5 : 20 )
@@ -244,7 +240,7 @@ function get_benchmark_sizes(size_categories::Vector{Symbol})
244
240
@warn " Unknown size category: $category . Skipping."
245
241
end
246
242
end
247
-
243
+
248
244
# Remove duplicates and sort
249
245
return sort (unique (sizes))
250
246
end
@@ -281,10 +277,10 @@ function categorize_results(df::DataFrame)
281
277
282
278
for eltype in eltypes
283
279
@info " Categorizing results for element type: $eltype "
284
-
280
+
285
281
# Filter results for this element type
286
282
eltype_df = filter (row -> row. eltype == eltype, successful_df)
287
-
283
+
288
284
if nrow (eltype_df) == 0
289
285
continue
290
286
end
@@ -299,27 +295,24 @@ function categorize_results(df::DataFrame)
299
295
300
296
# Calculate average GFLOPs for each algorithm in this range
301
297
avg_results = combine (groupby (range_df, :algorithm ), :gflops => mean => :avg_gflops )
302
-
298
+
303
299
# Sort by performance
304
- sort! (avg_results, :avg_gflops , rev = true )
300
+ sort! (avg_results, :avg_gflops , rev= true )
305
301
306
302
# Find the best algorithm (for complex types, avoid RFLU if possible)
307
303
if nrow (avg_results) > 0
308
304
best_alg = avg_results. algorithm[1 ]
309
-
305
+
310
306
# For complex types, check if best is RFLU and we have alternatives
311
- if (eltype == " ComplexF32" || eltype == " ComplexF64" ) &&
312
- (contains (best_alg, " RFLU" ) ||
313
- contains (best_alg, " RecursiveFactorization" ))
314
-
307
+ if (eltype == " ComplexF32" || eltype == " ComplexF64" ) &&
308
+ (contains (best_alg, " RFLU" ) || contains (best_alg, " RecursiveFactorization" ))
309
+
315
310
# Look for the best non-RFLU algorithm
316
311
for i in 2 : nrow (avg_results)
317
312
alt_alg = avg_results. algorithm[i]
318
- if ! contains (alt_alg, " RFLU" ) &&
319
- ! contains (alt_alg, " RecursiveFactorization" )
313
+ if ! contains (alt_alg, " RFLU" ) && ! contains (alt_alg, " RecursiveFactorization" )
320
314
# Check if performance difference is not too large (within 20%)
321
- perf_ratio = avg_results. avg_gflops[i] /
322
- avg_results. avg_gflops[1 ]
315
+ perf_ratio = avg_results. avg_gflops[i] / avg_results. avg_gflops[1 ]
323
316
if perf_ratio > 0.8
324
317
@info " Using $alt_alg instead of $best_alg for $eltype at $range_name ($(round (100 * perf_ratio, digits= 1 )) % of RFLU performance) to avoid complex number issues"
325
318
best_alg = alt_alg
@@ -330,7 +323,7 @@ function categorize_results(df::DataFrame)
330
323
end
331
324
end
332
325
end
333
-
326
+
334
327
category_key = " $(eltype) _$(range_name) "
335
328
categories[category_key] = best_alg
336
329
best_idx = findfirst (== (best_alg), avg_results. algorithm)
0 commit comments