Skip to content

Commit 5d564b6

Browse files
Revert "Apply JuliaFormatter to modified files"
This reverts commit 875dd22.
1 parent 84cc26c commit 5d564b6

File tree

3 files changed

+216
-241
lines changed

3 files changed

+216
-241
lines changed

lib/LinearSolveAutotune/src/benchmarking.jl

Lines changed: 44 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -13,38 +13,37 @@ Uses more strict rules for BLAS-dependent algorithms with non-standard types.
1313
function test_algorithm_compatibility(alg, eltype::Type, test_size::Int = 4)
1414
# Get algorithm name for type-specific compatibility rules
1515
alg_name = string(typeof(alg).name.name)
16-
16+
1717
# Define strict compatibility rules for BLAS-dependent algorithms
18-
if !(eltype <: LinearAlgebra.BLAS.BlasFloat) && alg_name in [
19-
"BLISFactorization", "MKLLUFactorization", "AppleAccelerateLUFactorization"]
18+
if !(eltype <: LinearAlgebra.BLAS.BlasFloat) && alg_name in ["BLISFactorization", "MKLLUFactorization", "AppleAccelerateLUFactorization"]
2019
return false # BLAS algorithms not compatible with non-standard types
2120
end
2221

2322
if alg_name == "BLISLUFactorization" && Sys.isapple()
2423
return false # BLISLUFactorization has no Apple Silicon binary
2524
end
26-
25+
2726
# For standard types or algorithms that passed the strict check, test functionality
2827
try
2928
# Create a small test problem with the specified element type
3029
rng = MersenneTwister(123)
3130
A = rand(rng, eltype, test_size, test_size)
3231
b = rand(rng, eltype, test_size)
3332
u0 = rand(rng, eltype, test_size)
34-
33+
3534
prob = LinearProblem(A, b; u0 = u0)
36-
35+
3736
# Try to solve - if it works, the algorithm is compatible
3837
sol = solve(prob, alg)
39-
38+
4039
# Additional check: verify the solution is actually of the expected type
4140
if !isa(sol.u, AbstractVector{eltype})
4241
@debug "Algorithm $alg_name returned wrong element type for $eltype"
4342
return false
4443
end
45-
44+
4645
return true
47-
46+
4847
catch e
4948
# Algorithm failed - not compatible with this element type
5049
@debug "Algorithm $alg_name failed for $eltype: $e"
@@ -61,14 +60,14 @@ Returns filtered algorithms and names.
6160
function filter_compatible_algorithms(algorithms, alg_names, eltype::Type)
6261
compatible_algs = []
6362
compatible_names = String[]
64-
63+
6564
for (alg, name) in zip(algorithms, alg_names)
6665
if test_algorithm_compatibility(alg, eltype)
6766
push!(compatible_algs, alg)
6867
push!(compatible_names, name)
6968
end
7069
end
71-
70+
7271
return compatible_algs, compatible_names
7372
end
7473

@@ -90,37 +89,36 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
9089

9190
# Initialize results DataFrame
9291
results_data = []
93-
92+
9493
# Calculate total number of benchmarks for progress bar
9594
total_benchmarks = 0
9695
for eltype in eltypes
9796
# Pre-filter to estimate the actual number
9897
test_algs, _ = filter_compatible_algorithms(algorithms, alg_names, eltype)
9998
total_benchmarks += length(matrix_sizes) * length(test_algs)
10099
end
101-
100+
102101
# Create progress bar
103-
progress = Progress(total_benchmarks, desc = "Benchmarking: ",
104-
barlen = 50, showspeed = true)
102+
progress = Progress(total_benchmarks, desc="Benchmarking: ",
103+
barlen=50, showspeed=true)
105104

106105
try
107106
for eltype in eltypes
108107
# Filter algorithms for this element type
109-
compatible_algs,
110-
compatible_names = filter_compatible_algorithms(algorithms, alg_names, eltype)
111-
108+
compatible_algs, compatible_names = filter_compatible_algorithms(algorithms, alg_names, eltype)
109+
112110
if isempty(compatible_algs)
113111
@warn "No algorithms compatible with $eltype, skipping..."
114112
continue
115113
end
116-
114+
117115
for n in matrix_sizes
118116
# Create test problem with specified element type
119117
rng = MersenneTwister(123) # Consistent seed for reproducibility
120118
A = rand(rng, eltype, n, n)
121119
b = rand(rng, eltype, n)
122120
u0 = rand(rng, eltype, n)
123-
121+
124122
# Compute reference solution with LUFactorization if correctness check is enabled
125123
reference_solution = nothing
126124
if check_correctness
@@ -135,9 +133,9 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
135133

136134
for (alg, name) in zip(compatible_algs, compatible_names)
137135
# Update progress description
138-
ProgressMeter.update!(progress,
139-
desc = "Benchmarking $name on $(n)×$(n) $eltype matrix: ")
140-
136+
ProgressMeter.update!(progress,
137+
desc="Benchmarking $name on $(n)×$(n) $eltype matrix: ")
138+
141139
gflops = 0.0
142140
success = true
143141
error_msg = ""
@@ -151,13 +149,12 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
151149

152150
# Warmup run and correctness check
153151
warmup_sol = solve(prob, alg)
154-
152+
155153
# Check correctness if reference solution is available
156154
if check_correctness && reference_solution !== nothing
157155
# Compute relative error
158-
rel_error = norm(warmup_sol.u - reference_solution.u) /
159-
norm(reference_solution.u)
160-
156+
rel_error = norm(warmup_sol.u - reference_solution.u) / norm(reference_solution.u)
157+
161158
if rel_error > correctness_tol
162159
passed_correctness = false
163160
@warn "Algorithm $name failed correctness check for size $n, eltype $eltype. " *
@@ -167,7 +164,7 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
167164
error_msg = "Failed correctness check (rel_error = $(round(rel_error, sigdigits=3)))"
168165
end
169166
end
170-
167+
171168
# Only benchmark if correctness check passed
172169
if passed_correctness
173170
# Actual benchmark
@@ -198,7 +195,7 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
198195
success = success,
199196
error = error_msg
200197
))
201-
198+
202199
# Update progress
203200
ProgressMeter.next!(progress)
204201
end
@@ -219,16 +216,15 @@ end
219216
Get the matrix sizes to benchmark based on the requested size categories.
220217
221218
Size categories:
222-
223-
- `:tiny` - 5:5:20 (for very small problems)
224-
- `:small` - 20:20:100 (for small problems)
225-
- `:medium` - 100:50:300 (for typical problems)
226-
- `:large` - 300:100:1000 (for larger problems)
227-
- `:big` - vcat(1000:2000:10000, 10000:5000:20000) (for very large/GPU problems)
219+
- `:tiny` - 5:5:20 (for very small problems)
220+
- `:small` - 20:20:100 (for small problems)
221+
- `:medium` - 100:50:300 (for typical problems)
222+
- `:large` - 300:100:1000 (for larger problems)
223+
- `:big` - vcat(1000:2000:10000, 10000:5000:20000) (for very large/GPU problems)
228224
"""
229225
function get_benchmark_sizes(size_categories::Vector{Symbol})
230226
sizes = Int[]
231-
227+
232228
for category in size_categories
233229
if category == :tiny
234230
append!(sizes, 5:5:20)
@@ -244,7 +240,7 @@ function get_benchmark_sizes(size_categories::Vector{Symbol})
244240
@warn "Unknown size category: $category. Skipping."
245241
end
246242
end
247-
243+
248244
# Remove duplicates and sort
249245
return sort(unique(sizes))
250246
end
@@ -281,10 +277,10 @@ function categorize_results(df::DataFrame)
281277

282278
for eltype in eltypes
283279
@info "Categorizing results for element type: $eltype"
284-
280+
285281
# Filter results for this element type
286282
eltype_df = filter(row -> row.eltype == eltype, successful_df)
287-
283+
288284
if nrow(eltype_df) == 0
289285
continue
290286
end
@@ -299,27 +295,24 @@ function categorize_results(df::DataFrame)
299295

300296
# Calculate average GFLOPs for each algorithm in this range
301297
avg_results = combine(groupby(range_df, :algorithm), :gflops => mean => :avg_gflops)
302-
298+
303299
# Sort by performance
304-
sort!(avg_results, :avg_gflops, rev = true)
300+
sort!(avg_results, :avg_gflops, rev=true)
305301

306302
# Find the best algorithm (for complex types, avoid RFLU if possible)
307303
if nrow(avg_results) > 0
308304
best_alg = avg_results.algorithm[1]
309-
305+
310306
# For complex types, check if best is RFLU and we have alternatives
311-
if (eltype == "ComplexF32" || eltype == "ComplexF64") &&
312-
(contains(best_alg, "RFLU") ||
313-
contains(best_alg, "RecursiveFactorization"))
314-
307+
if (eltype == "ComplexF32" || eltype == "ComplexF64") &&
308+
(contains(best_alg, "RFLU") || contains(best_alg, "RecursiveFactorization"))
309+
315310
# Look for the best non-RFLU algorithm
316311
for i in 2:nrow(avg_results)
317312
alt_alg = avg_results.algorithm[i]
318-
if !contains(alt_alg, "RFLU") &&
319-
!contains(alt_alg, "RecursiveFactorization")
313+
if !contains(alt_alg, "RFLU") && !contains(alt_alg, "RecursiveFactorization")
320314
# Check if performance difference is not too large (within 20%)
321-
perf_ratio = avg_results.avg_gflops[i] /
322-
avg_results.avg_gflops[1]
315+
perf_ratio = avg_results.avg_gflops[i] / avg_results.avg_gflops[1]
323316
if perf_ratio > 0.8
324317
@info "Using $alt_alg instead of $best_alg for $eltype at $range_name ($(round(100*perf_ratio, digits=1))% of RFLU performance) to avoid complex number issues"
325318
best_alg = alt_alg
@@ -330,7 +323,7 @@ function categorize_results(df::DataFrame)
330323
end
331324
end
332325
end
333-
326+
334327
category_key = "$(eltype)_$(range_name)"
335328
categories[category_key] = best_alg
336329
best_idx = findfirst(==(best_alg), avg_results.algorithm)

0 commit comments

Comments
 (0)