Skip to content

Commit 875dd22

Browse files
Apply JuliaFormatter to modified files
1 parent 844beec commit 875dd22

File tree

4 files changed

+251
-216
lines changed

4 files changed

+251
-216
lines changed

format_autotune.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using Pkg
2+
Pkg.add("JuliaFormatter")
3+
using JuliaFormatter
4+
5+
# Format only the changed files with SciMLStyle
6+
format("lib/LinearSolveAutotune/src/gpu_detection.jl", SciMLStyle())
7+
format("lib/LinearSolveAutotune/src/telemetry.jl", SciMLStyle())
8+
format("lib/LinearSolveAutotune/src/benchmarking.jl", SciMLStyle())
9+
10+
println("Formatting complete!")

lib/LinearSolveAutotune/src/benchmarking.jl

Lines changed: 51 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -13,37 +13,38 @@ Uses more strict rules for BLAS-dependent algorithms with non-standard types.
1313
function test_algorithm_compatibility(alg, eltype::Type, test_size::Int = 4)
1414
# Get algorithm name for type-specific compatibility rules
1515
alg_name = string(typeof(alg).name.name)
16-
16+
1717
# Define strict compatibility rules for BLAS-dependent algorithms
18-
if !(eltype <: LinearAlgebra.BLAS.BlasFloat) && alg_name in ["BLISFactorization", "MKLLUFactorization", "AppleAccelerateLUFactorization"]
18+
if !(eltype <: LinearAlgebra.BLAS.BlasFloat) && alg_name in [
19+
"BLISFactorization", "MKLLUFactorization", "AppleAccelerateLUFactorization"]
1920
return false # BLAS algorithms not compatible with non-standard types
2021
end
2122

2223
if alg_name == "BLISLUFactorization" && Sys.isapple()
2324
return false # BLISLUFactorization has no Apple Silicon binary
2425
end
25-
26+
2627
# For standard types or algorithms that passed the strict check, test functionality
2728
try
2829
# Create a small test problem with the specified element type
2930
rng = MersenneTwister(123)
3031
A = rand(rng, eltype, test_size, test_size)
3132
b = rand(rng, eltype, test_size)
3233
u0 = rand(rng, eltype, test_size)
33-
34+
3435
prob = LinearProblem(A, b; u0 = u0)
35-
36+
3637
# Try to solve - if it works, the algorithm is compatible
3738
sol = solve(prob, alg)
38-
39+
3940
# Additional check: verify the solution is actually of the expected type
4041
if !isa(sol.u, AbstractVector{eltype})
4142
@debug "Algorithm $alg_name returned wrong element type for $eltype"
4243
return false
4344
end
44-
45+
4546
return true
46-
47+
4748
catch e
4849
# Algorithm failed - not compatible with this element type
4950
@debug "Algorithm $alg_name failed for $eltype: $e"
@@ -60,14 +61,14 @@ Returns filtered algorithms and names.
6061
function filter_compatible_algorithms(algorithms, alg_names, eltype::Type)
6162
compatible_algs = []
6263
compatible_names = String[]
63-
64+
6465
for (alg, name) in zip(algorithms, alg_names)
6566
if test_algorithm_compatibility(alg, eltype)
6667
push!(compatible_algs, alg)
6768
push!(compatible_names, name)
6869
end
6970
end
70-
71+
7172
return compatible_algs, compatible_names
7273
end
7374

@@ -89,36 +90,37 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
8990

9091
# Initialize results DataFrame
9192
results_data = []
92-
93+
9394
# Calculate total number of benchmarks for progress bar
9495
total_benchmarks = 0
9596
for eltype in eltypes
9697
# Pre-filter to estimate the actual number
9798
test_algs, _ = filter_compatible_algorithms(algorithms, alg_names, eltype)
9899
total_benchmarks += length(matrix_sizes) * length(test_algs)
99100
end
100-
101+
101102
# Create progress bar
102-
progress = Progress(total_benchmarks, desc="Benchmarking: ",
103-
barlen=50, showspeed=true)
103+
progress = Progress(total_benchmarks, desc = "Benchmarking: ",
104+
barlen = 50, showspeed = true)
104105

105106
try
106107
for eltype in eltypes
107108
# Filter algorithms for this element type
108-
compatible_algs, compatible_names = filter_compatible_algorithms(algorithms, alg_names, eltype)
109-
109+
compatible_algs,
110+
compatible_names = filter_compatible_algorithms(algorithms, alg_names, eltype)
111+
110112
if isempty(compatible_algs)
111113
@warn "No algorithms compatible with $eltype, skipping..."
112114
continue
113115
end
114-
116+
115117
for n in matrix_sizes
116118
# Create test problem with specified element type
117119
rng = MersenneTwister(123) # Consistent seed for reproducibility
118120
A = rand(rng, eltype, n, n)
119121
b = rand(rng, eltype, n)
120122
u0 = rand(rng, eltype, n)
121-
123+
122124
# Compute reference solution with LUFactorization if correctness check is enabled
123125
reference_solution = nothing
124126
if check_correctness
@@ -133,9 +135,9 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
133135

134136
for (alg, name) in zip(compatible_algs, compatible_names)
135137
# Update progress description
136-
ProgressMeter.update!(progress,
137-
desc="Benchmarking $name on $(n)×$(n) $eltype matrix: ")
138-
138+
ProgressMeter.update!(progress,
139+
desc = "Benchmarking $name on $(n)×$(n) $eltype matrix: ")
140+
139141
gflops = 0.0
140142
success = true
141143
error_msg = ""
@@ -149,12 +151,13 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
149151

150152
# Warmup run and correctness check
151153
warmup_sol = solve(prob, alg)
152-
154+
153155
# Check correctness if reference solution is available
154156
if check_correctness && reference_solution !== nothing
155157
# Compute relative error
156-
rel_error = norm(warmup_sol.u - reference_solution.u) / norm(reference_solution.u)
157-
158+
rel_error = norm(warmup_sol.u - reference_solution.u) /
159+
norm(reference_solution.u)
160+
158161
if rel_error > correctness_tol
159162
passed_correctness = false
160163
@warn "Algorithm $name failed correctness check for size $n, eltype $eltype. " *
@@ -164,7 +167,7 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
164167
error_msg = "Failed correctness check (rel_error = $(round(rel_error, sigdigits=3)))"
165168
end
166169
end
167-
170+
168171
# Only benchmark if correctness check passed
169172
if passed_correctness
170173
# Actual benchmark
@@ -195,7 +198,7 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
195198
success = success,
196199
error = error_msg
197200
))
198-
201+
199202
# Update progress
200203
ProgressMeter.next!(progress)
201204
end
@@ -216,15 +219,16 @@ end
216219
Get the matrix sizes to benchmark based on the requested size categories.
217220
218221
Size categories:
219-
- `:tiny` - 5:5:20 (for very small problems)
220-
- `:small` - 20:20:100 (for small problems)
221-
- `:medium` - 100:50:300 (for typical problems)
222-
- `:large` - 300:100:1000 (for larger problems)
223-
- `:big` - vcat(1000:2000:10000, 10000:5000:20000) (for very large/GPU problems)
222+
223+
- `:tiny` - 5:5:20 (for very small problems)
224+
- `:small` - 20:20:100 (for small problems)
225+
- `:medium` - 100:50:300 (for typical problems)
226+
- `:large` - 300:100:1000 (for larger problems)
227+
- `:big` - vcat(1000:2000:10000, 10000:5000:20000) (for very large/GPU problems)
224228
"""
225229
function get_benchmark_sizes(size_categories::Vector{Symbol})
226230
sizes = Int[]
227-
231+
228232
for category in size_categories
229233
if category == :tiny
230234
append!(sizes, 5:5:20)
@@ -240,7 +244,7 @@ function get_benchmark_sizes(size_categories::Vector{Symbol})
240244
@warn "Unknown size category: $category. Skipping."
241245
end
242246
end
243-
247+
244248
# Remove duplicates and sort
245249
return sort(unique(sizes))
246250
end
@@ -277,10 +281,10 @@ function categorize_results(df::DataFrame)
277281

278282
for eltype in eltypes
279283
@info "Categorizing results for element type: $eltype"
280-
284+
281285
# Filter results for this element type
282286
eltype_df = filter(row -> row.eltype == eltype, successful_df)
283-
287+
284288
if nrow(eltype_df) == 0
285289
continue
286290
end
@@ -295,24 +299,27 @@ function categorize_results(df::DataFrame)
295299

296300
# Calculate average GFLOPs for each algorithm in this range
297301
avg_results = combine(groupby(range_df, :algorithm), :gflops => mean => :avg_gflops)
298-
302+
299303
# Sort by performance
300-
sort!(avg_results, :avg_gflops, rev=true)
304+
sort!(avg_results, :avg_gflops, rev = true)
301305

302306
# Find the best algorithm (for complex types, avoid RFLU if possible)
303307
if nrow(avg_results) > 0
304308
best_alg = avg_results.algorithm[1]
305-
309+
306310
# For complex types, check if best is RFLU and we have alternatives
307-
if (eltype == "ComplexF32" || eltype == "ComplexF64") &&
308-
(contains(best_alg, "RFLU") || contains(best_alg, "RecursiveFactorization"))
309-
311+
if (eltype == "ComplexF32" || eltype == "ComplexF64") &&
312+
(contains(best_alg, "RFLU") ||
313+
contains(best_alg, "RecursiveFactorization"))
314+
310315
# Look for the best non-RFLU algorithm
311316
for i in 2:nrow(avg_results)
312317
alt_alg = avg_results.algorithm[i]
313-
if !contains(alt_alg, "RFLU") && !contains(alt_alg, "RecursiveFactorization")
318+
if !contains(alt_alg, "RFLU") &&
319+
!contains(alt_alg, "RecursiveFactorization")
314320
# Check if performance difference is not too large (within 20%)
315-
perf_ratio = avg_results.avg_gflops[i] / avg_results.avg_gflops[1]
321+
perf_ratio = avg_results.avg_gflops[i] /
322+
avg_results.avg_gflops[1]
316323
if perf_ratio > 0.8
317324
@info "Using $alt_alg instead of $best_alg for $eltype at $range_name ($(round(100*perf_ratio, digits=1))% of RFLU performance) to avoid complex number issues"
318325
best_alg = alt_alg
@@ -323,7 +330,7 @@ function categorize_results(df::DataFrame)
323330
end
324331
end
325332
end
326-
333+
327334
category_key = "$(eltype)_$(range_name)"
328335
categories[category_key] = best_alg
329336
best_idx = findfirst(==(best_alg), avg_results.algorithm)

0 commit comments

Comments
 (0)