Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolveAutotune = "67398393-80e8-4254-b7e4-1b9a36a3c5b6"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Expand Down
16 changes: 8 additions & 8 deletions lib/LinearSolveAutotune/src/LinearSolveAutotune.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ function Base.show(io::IO, results::AutotuneResults)
println(io, "📏 Matrix Sizes: ", minimum(sizes), "×", minimum(sizes),
" to ", maximum(sizes), "×", maximum(sizes))

# Report timeouts if any
timeout_results = filter(row -> isnan(row.gflops), results.results_df)
if nrow(timeout_results) > 0
println(io, "⏱️ Timed Out: ", nrow(timeout_results), " tests exceeded time limit")
# Report tests that exceeded maxtime if any
exceeded_results = filter(row -> isnan(row.gflops) && contains(get(row, :error, ""), "Exceeded maxtime"), results.results_df)
if nrow(exceeded_results) > 0
println(io, "⏱️ Exceeded maxtime: ", nrow(exceeded_results), " tests exceeded time limit")
end

# Call to action - reordered
Expand Down Expand Up @@ -265,17 +265,17 @@ function autotune_setup(;

# Display results table - filter out NaN values
successful_results = filter(row -> row.success && !isnan(row.gflops), results_df)
timeout_results = filter(row -> isnan(row.gflops) && !contains(get(row, :error, ""), "Skipped"), results_df)
exceeded_maxtime_results = filter(row -> isnan(row.gflops) && contains(get(row, :error, ""), "Exceeded maxtime"), results_df)
skipped_results = filter(row -> contains(get(row, :error, ""), "Skipped"), results_df)

if nrow(timeout_results) > 0
@info "$(nrow(timeout_results)) tests timed out (exceeded $(maxtime)s limit)"
if nrow(exceeded_maxtime_results) > 0
@info "$(nrow(exceeded_maxtime_results)) tests exceeded maxtime limit ($(maxtime)s)"
end

if nrow(skipped_results) > 0
# Count unique algorithms that were skipped
skipped_algs = unique([row.algorithm for row in eachrow(skipped_results)])
@info "$(length(skipped_algs)) algorithms skipped for larger matrices after timing out"
@info "$(length(skipped_algs)) algorithms skipped for larger matrices after exceeding maxtime"
end

if nrow(successful_results) > 0
Expand Down
120 changes: 40 additions & 80 deletions lib/LinearSolveAutotune/src/benchmarking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
# Initialize results DataFrame
results_data = []

# Track algorithms that have timed out (per element type)
timed_out_algorithms = Dict{String, Set{String}}() # eltype => Set of algorithm names
# Track algorithms that have exceeded maxtime (per element type and size)
# Structure: eltype => algorithm_name => max_size_tested
blocked_algorithms = Dict{String, Dict{String, Int}}() # eltype => Dict(algorithm_name => max_size)

# Calculate total number of benchmarks for progress bar
total_benchmarks = 0
Expand All @@ -112,8 +113,8 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;

try
for eltype in eltypes
# Initialize timed out set for this element type
timed_out_algorithms[string(eltype)] = Set{String}()
# Initialize blocked algorithms dict for this element type
blocked_algorithms[string(eltype)] = Dict{String, Int}()

# Filter algorithms for this element type
compatible_algs, compatible_names = filter_compatible_algorithms(algorithms, alg_names, eltype)
Expand Down Expand Up @@ -143,32 +144,35 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
end

for (alg, name) in zip(compatible_algs, compatible_names)
# Skip this algorithm if it has already timed out for this element type
if name in timed_out_algorithms[string(eltype)]
# Still need to update progress bar
ProgressMeter.next!(progress)
# Record as skipped due to previous timeout
push!(results_data,
(
size = n,
algorithm = name,
eltype = string(eltype),
gflops = NaN,
success = false,
error = "Skipped: timed out on smaller matrix"
))
continue
# Skip this algorithm if it already exceeded maxtime on a smaller matrix (only sizes strictly larger than the recorded one are skipped)
if haskey(blocked_algorithms[string(eltype)], name)
max_allowed_size = blocked_algorithms[string(eltype)][name]
if n > max_allowed_size
# Still need to update progress bar
ProgressMeter.next!(progress)
# Record as skipped due to exceeding maxtime on smaller matrix
push!(results_data,
(
size = n,
algorithm = name,
eltype = string(eltype),
gflops = NaN,
success = false,
error = "Skipped: exceeded maxtime on size $max_allowed_size matrix"
))
continue
end
end

# Update progress description
ProgressMeter.update!(progress,
desc="Benchmarking $name on $(n)×$(n) $eltype matrix: ")

gflops = NaN # Use NaN for timed out runs
gflops = NaN # Use NaN for failed/timed out runs
success = true
error_msg = ""
passed_correctness = true
timed_out = false
exceeded_maxtime = false

try
# Create the linear problem for this test
Expand All @@ -179,69 +183,25 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
# Time the warmup run and correctness check
start_time = time()

# Warmup run and correctness check with simple timeout
# Warmup run and correctness check - no interruption, just timing
warmup_sol = nothing
timed_out_flag = false

# Try to run with a timeout - simpler approach without async
try
# Create a task for the solve
done_channel = Channel(1)
error_channel = Channel(1)

warmup_task = @async begin
try
result = solve(prob, alg)
put!(done_channel, result)
catch e
put!(error_channel, e)
end
end

# Wait for completion or timeout
timeout_occurred = false
result = nothing

# Use timedwait which is more reliable than manual polling
wait_result = timedwait(() -> istaskdone(warmup_task), maxtime)

if wait_result === :timed_out
timeout_occurred = true
timed_out_flag = true
# Don't try to kill the task - just mark it as timed out
# The task will continue running in background but we move on
else
# Task completed - get the result
if isready(done_channel)
warmup_sol = take!(done_channel)
elseif isready(error_channel)
throw(take!(error_channel))
end
end

# Close channels to prevent resource leaks
close(done_channel)
close(error_channel)

catch e
# If an error occurred during solve, re-throw it
if !timed_out_flag
throw(e)
end
end
# Simply run the solve and measure time
warmup_sol = solve(prob, alg)
elapsed_time = time() - start_time

if timed_out_flag
# Task timed out
timed_out = true
# Add to timed out set so it's skipped for larger matrices
push!(timed_out_algorithms[string(eltype)], name)
@warn "Algorithm $name timed out (exceeded $(maxtime)s) for size $n, eltype $eltype. Will skip for larger matrices."
# Check if we exceeded maxtime
if elapsed_time > maxtime
exceeded_maxtime = true
# Block this algorithm for larger matrices
# Store the last size that was allowed to complete
blocked_algorithms[string(eltype)][name] = n
@warn "Algorithm $name exceeded maxtime ($(round(elapsed_time, digits=2))s > $(maxtime)s) for size $n, eltype $eltype. Will skip for larger matrices."
success = false
error_msg = "Timed out (exceeded $(maxtime)s)"
error_msg = "Exceeded maxtime ($(round(elapsed_time, digits=2))s)"
gflops = NaN
else
# Successful completion
elapsed_time = time() - start_time
# Successful completion within time limit

# Check correctness if reference solution is available
if check_correctness && reference_solution !== nothing
Expand All @@ -259,8 +219,8 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
end
end

# Only benchmark if correctness check passed and we have time remaining
if passed_correctness && !timed_out
# Only benchmark if correctness check passed and we didn't exceed maxtime
if passed_correctness && !exceeded_maxtime
# Check if we have enough time remaining for benchmarking
# Allow at least 2x the warmup time for benchmarking
remaining_time = maxtime - elapsed_time
Expand Down
Loading