Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions lib/LinearSolveAutotune/src/LinearSolveAutotune.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ function Base.show(io::IO, results::AutotuneResults)
println(io, "📏 Matrix Sizes: ", minimum(sizes), "×", minimum(sizes),
" to ", maximum(sizes), "×", maximum(sizes))

# Report timeouts if any
timeout_results = filter(row -> isnan(row.gflops), results.results_df)
if nrow(timeout_results) > 0
println(io, "⏱️ Timed Out: ", nrow(timeout_results), " tests exceeded time limit")
# Report tests that exceeded maxtime if any
exceeded_results = filter(row -> isnan(row.gflops) && contains(get(row, :error, ""), "Exceeded maxtime"), results.results_df)
if nrow(exceeded_results) > 0
println(io, "⏱️ Exceeded maxtime: ", nrow(exceeded_results), " tests exceeded time limit")
end

# Call to action - reordered
Expand Down Expand Up @@ -265,17 +265,17 @@ function autotune_setup(;

# Display results table - filter out NaN values
successful_results = filter(row -> row.success && !isnan(row.gflops), results_df)
timeout_results = filter(row -> isnan(row.gflops) && !contains(get(row, :error, ""), "Skipped"), results_df)
exceeded_maxtime_results = filter(row -> isnan(row.gflops) && contains(get(row, :error, ""), "Exceeded maxtime"), results_df)
skipped_results = filter(row -> contains(get(row, :error, ""), "Skipped"), results_df)

if nrow(timeout_results) > 0
@info "$(nrow(timeout_results)) tests timed out (exceeded $(maxtime)s limit)"
if nrow(exceeded_maxtime_results) > 0
@info "$(nrow(exceeded_maxtime_results)) tests exceeded maxtime limit ($(maxtime)s)"
end

if nrow(skipped_results) > 0
# Count unique algorithms that were skipped
skipped_algs = unique([row.algorithm for row in eachrow(skipped_results)])
@info "$(length(skipped_algs)) algorithms skipped for larger matrices after timing out"
@info "$(length(skipped_algs)) algorithms skipped for larger matrices after exceeding maxtime"
end

if nrow(successful_results) > 0
Expand Down
120 changes: 40 additions & 80 deletions lib/LinearSolveAutotune/src/benchmarking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
# Initialize results DataFrame
results_data = []

# Track algorithms that have timed out (per element type)
timed_out_algorithms = Dict{String, Set{String}}() # eltype => Set of algorithm names
# Track algorithms that have exceeded maxtime, per element type, along with the size at which they did
# Structure: eltype => algorithm_name => largest size tested (larger sizes are skipped)
blocked_algorithms = Dict{String, Dict{String, Int}}() # eltype => Dict(algorithm_name => max_size)

# Calculate total number of benchmarks for progress bar
total_benchmarks = 0
Expand All @@ -112,8 +113,8 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;

try
for eltype in eltypes
# Initialize timed out set for this element type
timed_out_algorithms[string(eltype)] = Set{String}()
# Initialize blocked algorithms dict for this element type
blocked_algorithms[string(eltype)] = Dict{String, Int}()

# Filter algorithms for this element type
compatible_algs, compatible_names = filter_compatible_algorithms(algorithms, alg_names, eltype)
Expand Down Expand Up @@ -143,32 +144,35 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
end

for (alg, name) in zip(compatible_algs, compatible_names)
# Skip this algorithm if it has already timed out for this element type
if name in timed_out_algorithms[string(eltype)]
# Still need to update progress bar
ProgressMeter.next!(progress)
# Record as skipped due to previous timeout
push!(results_data,
(
size = n,
algorithm = name,
eltype = string(eltype),
gflops = NaN,
success = false,
error = "Skipped: timed out on smaller matrix"
))
continue
# Skip this algorithm if it has already exceeded maxtime on a strictly smaller matrix
if haskey(blocked_algorithms[string(eltype)], name)
max_allowed_size = blocked_algorithms[string(eltype)][name]
if n > max_allowed_size
# Still need to update progress bar
ProgressMeter.next!(progress)
# Record as skipped due to exceeding maxtime on smaller matrix
push!(results_data,
(
size = n,
algorithm = name,
eltype = string(eltype),
gflops = NaN,
success = false,
error = "Skipped: exceeded maxtime on size $max_allowed_size matrix"
))
continue
end
end

# Update progress description
ProgressMeter.update!(progress,
desc="Benchmarking $name on $(n)×$(n) $eltype matrix: ")

gflops = NaN # Use NaN for timed out runs
gflops = NaN # Use NaN for failed/timed out runs
success = true
error_msg = ""
passed_correctness = true
timed_out = false
exceeded_maxtime = false

try
# Create the linear problem for this test
Expand All @@ -179,69 +183,25 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
# Time the warmup run and correctness check
start_time = time()

# Warmup run and correctness check with simple timeout
# Warmup run and correctness check - no interruption, just timing
warmup_sol = nothing
timed_out_flag = false

# Try to run with a timeout - simpler approach without async
try
# Create a task for the solve
done_channel = Channel(1)
error_channel = Channel(1)

warmup_task = @async begin
try
result = solve(prob, alg)
put!(done_channel, result)
catch e
put!(error_channel, e)
end
end

# Wait for completion or timeout
timeout_occurred = false
result = nothing

# Use timedwait which is more reliable than manual polling
wait_result = timedwait(() -> istaskdone(warmup_task), maxtime)

if wait_result === :timed_out
timeout_occurred = true
timed_out_flag = true
# Don't try to kill the task - just mark it as timed out
# The task will continue running in background but we move on
else
# Task completed - get the result
if isready(done_channel)
warmup_sol = take!(done_channel)
elseif isready(error_channel)
throw(take!(error_channel))
end
end

# Close channels to prevent resource leaks
close(done_channel)
close(error_channel)

catch e
# If an error occurred during solve, re-throw it
if !timed_out_flag
throw(e)
end
end
# Simply run the solve and measure time
warmup_sol = solve(prob, alg)
elapsed_time = time() - start_time

if timed_out_flag
# Task timed out
timed_out = true
# Add to timed out set so it's skipped for larger matrices
push!(timed_out_algorithms[string(eltype)], name)
@warn "Algorithm $name timed out (exceeded $(maxtime)s) for size $n, eltype $eltype. Will skip for larger matrices."
# Check if we exceeded maxtime
if elapsed_time > maxtime
exceeded_maxtime = true
# Block this algorithm for larger matrices
# Store the last size that was allowed to complete
blocked_algorithms[string(eltype)][name] = n
@warn "Algorithm $name exceeded maxtime ($(round(elapsed_time, digits=2))s > $(maxtime)s) for size $n, eltype $eltype. Will skip for larger matrices."
success = false
error_msg = "Timed out (exceeded $(maxtime)s)"
error_msg = "Exceeded maxtime ($(round(elapsed_time, digits=2))s)"
gflops = NaN
else
# Successful completion
elapsed_time = time() - start_time
# Successful completion within time limit

# Check correctness if reference solution is available
if check_correctness && reference_solution !== nothing
Expand All @@ -259,8 +219,8 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
end
end

# Only benchmark if correctness check passed and we have time remaining
if passed_correctness && !timed_out
# Only benchmark if correctness check passed and we didn't exceed maxtime
if passed_correctness && !exceeded_maxtime
# Check if we have enough time remaining for benchmarking
# Allow at least 2x the warmup time for benchmarking
remaining_time = maxtime - elapsed_time
Expand Down
Loading