Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions docs/src/tutorials/autotune.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,37 @@ results = autotune_setup(
)
```

### Time Limits for Algorithm Tests

Use the `maxtime` keyword to control the maximum time, in seconds, allowed for each algorithm test (including the accuracy check):

```julia
# Default: 100 seconds maximum per algorithm test
results = autotune_setup() # maxtime = 100.0

# Quick timeout for fast exploration
results = autotune_setup(maxtime = 10.0)

# Extended timeout for slow algorithms or large matrices
results = autotune_setup(
maxtime = 300.0, # 5 minutes per test
sizes = [:large, :big]
)

# Conservative timeout for production benchmarking
results = autotune_setup(
maxtime = 200.0,
samples = 10,
seconds = 2.0
)
```

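When an algorithm exceeds the `maxtime` limit:
- The test is skipped to prevent hanging
- The result is recorded as `NaN` in the benchmark data
- A warning is displayed indicating the timeout
- The benchmark continues with the next algorithm

A result is also recorded as `NaN` when the warmup run finishes within `maxtime` but the time remaining is less than twice the warmup time; in that case the benchmark itself is skipped and a warning reports insufficient time for benchmarking.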

### Missing Algorithm Handling

By default, autotune expects all algorithms to be available to ensure complete benchmarking. You can relax this requirement:
Expand Down
11 changes: 8 additions & 3 deletions lib/LinearSolveAutotune/src/LinearSolveAutotune.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ end
seconds::Float64 = 0.5,
eltypes = (Float32, Float64, ComplexF32, ComplexF64),
skip_missing_algs::Bool = false,
include_fastlapack::Bool = false)
include_fastlapack::Bool = false,
maxtime::Float64 = 100.0)

Run a comprehensive benchmark of all available LU factorization methods and optionally:

Expand All @@ -182,6 +183,8 @@ Run a comprehensive benchmark of all available LU factorization methods and opti
- `eltypes = (Float32, Float64, ComplexF32, ComplexF64)`: Element types to benchmark
- `skip_missing_algs::Bool = false`: If false, error when expected algorithms are missing; if true, warn instead
- `include_fastlapack::Bool = false`: If true, includes FastLUFactorization in benchmarks
- `maxtime::Float64 = 100.0`: Maximum time in seconds for each algorithm test (including the accuracy check).
If exceeded, the run is skipped and recorded as `NaN`.

# Returns

Expand Down Expand Up @@ -216,7 +219,8 @@ function autotune_setup(;
seconds::Float64 = 0.5,
eltypes = (Float64,),
skip_missing_algs::Bool = false,
include_fastlapack::Bool = false)
include_fastlapack::Bool = false,
maxtime::Float64 = 100.0)
@info "Starting LinearSolve.jl autotune setup..."
@info "Configuration: sizes=$sizes, set_preferences=$set_preferences"
@info "Element types to benchmark: $(join(eltypes, ", "))"
Expand Down Expand Up @@ -249,8 +253,9 @@ function autotune_setup(;

# Run benchmarks
@info "Running benchmarks (this may take several minutes)..."
@info "Maximum time per algorithm test: $(maxtime)s"
results_df = benchmark_algorithms(matrix_sizes, all_algs, all_names, eltypes;
samples = samples, seconds = seconds, sizes = sizes)
samples = samples, seconds = seconds, sizes = sizes, maxtime = maxtime)

# Display results table
successful_results = filter(row -> row.success, results_df)
Expand Down
131 changes: 102 additions & 29 deletions lib/LinearSolveAutotune/src/benchmarking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,19 @@ end

"""
benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
samples=5, seconds=0.5, sizes=[:small, :medium])
samples=5, seconds=0.5, sizes=[:small, :medium],
maxtime=100.0)

Benchmark the given algorithms across different matrix sizes and element types.
Returns a DataFrame with results including element type information.

# Arguments
- `maxtime::Float64 = 100.0`: Maximum time in seconds for each algorithm test (including accuracy check).
If the accuracy check exceeds this time, the run is skipped and recorded as NaN.
"""
function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
samples = 5, seconds = 0.5, sizes = [:tiny, :small, :medium, :large],
check_correctness = true, correctness_tol = 1e0)
check_correctness = true, correctness_tol = 1e0, maxtime = 100.0)

# Set benchmark parameters
old_params = BenchmarkTools.DEFAULT_PARAMETERS
Expand Down Expand Up @@ -136,52 +141,120 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
ProgressMeter.update!(progress,
desc="Benchmarking $name on $(n)×$(n) $eltype matrix: ")

gflops = 0.0
gflops = NaN # Use NaN for timed out runs
success = true
error_msg = ""
passed_correctness = true
timed_out = false

try
# Create the linear problem for this test
prob = LinearProblem(copy(A), copy(b);
u0 = copy(u0),
alias = LinearAliasSpecifier(alias_A = true, alias_b = true))

# Warmup run and correctness check
warmup_sol = solve(prob, alg)
# Time the warmup run and correctness check
start_time = time()

# Check correctness if reference solution is available
if check_correctness && reference_solution !== nothing
# Compute relative error
rel_error = norm(warmup_sol.u - reference_solution.u) / norm(reference_solution.u)

if rel_error > correctness_tol
passed_correctness = false
@warn "Algorithm $name failed correctness check for size $n, eltype $eltype. " *
"Relative error: $(round(rel_error, sigdigits=3)) > tolerance: $correctness_tol. " *
"Algorithm will be excluded from results."
success = false
error_msg = "Failed correctness check (rel_error = $(round(rel_error, sigdigits=3)))"
# Create a channel for communication between tasks
result_channel = Channel(1)

# Warmup run and correctness check with timeout
warmup_task = @async begin
try
result = solve(prob, alg)
put!(result_channel, result)
catch e
put!(result_channel, e)
end
end

# Timer task to enforce timeout
timer_task = @async begin
sleep(maxtime)
if !istaskdone(warmup_task)
try
Base.throwto(warmup_task, InterruptException())
catch
# Task might be in non-interruptible state
end
put!(result_channel, :timeout)
end
end

# Wait for result or timeout
warmup_sol = nothing
result = take!(result_channel)

# Clean up timer task if still running
if !istaskdone(timer_task)
try
Base.throwto(timer_task, InterruptException())
catch
# Timer task might have already finished
end
end

# Only benchmark if correctness check passed
if passed_correctness
# Actual benchmark
bench = @benchmark solve($prob, $alg) setup=(prob = LinearProblem(
copy($A), copy($b);
u0 = copy($u0),
alias = LinearAliasSpecifier(alias_A = true, alias_b = true)))

# Calculate GFLOPs
min_time_sec = minimum(bench.times) / 1e9
flops = luflop(n, n)
gflops = flops / min_time_sec / 1e9
if result === :timeout
# Task timed out
timed_out = true
@warn "Algorithm $name timed out (exceeded $(maxtime)s) for size $n, eltype $eltype. Recording as NaN."
success = false
error_msg = "Timed out (exceeded $(maxtime)s)"
gflops = NaN
elseif result isa Exception
# Task threw an error
throw(result)
else
# Successful completion
warmup_sol = result
elapsed_time = time() - start_time

# Check correctness if reference solution is available
if check_correctness && reference_solution !== nothing
# Compute relative error
rel_error = norm(warmup_sol.u - reference_solution.u) / norm(reference_solution.u)

if rel_error > correctness_tol
passed_correctness = false
@warn "Algorithm $name failed correctness check for size $n, eltype $eltype. " *
"Relative error: $(round(rel_error, sigdigits=3)) > tolerance: $correctness_tol. " *
"Algorithm will be excluded from results."
success = false
error_msg = "Failed correctness check (rel_error = $(round(rel_error, sigdigits=3)))"
gflops = 0.0
end
end

# Only benchmark if correctness check passed and we have time remaining
if passed_correctness && !timed_out
# Check if we have enough time remaining for benchmarking
# Allow at least 2x the warmup time for benchmarking
remaining_time = maxtime - elapsed_time
if remaining_time < 2 * elapsed_time
@warn "Algorithm $name: insufficient time remaining for benchmarking (warmup took $(round(elapsed_time, digits=2))s). Recording as NaN."
gflops = NaN
success = false
error_msg = "Insufficient time for benchmarking"
else
# Actual benchmark
bench = @benchmark solve($prob, $alg) setup=(prob = LinearProblem(
copy($A), copy($b);
u0 = copy($u0),
alias = LinearAliasSpecifier(alias_A = true, alias_b = true)))

# Calculate GFLOPs
min_time_sec = minimum(bench.times) / 1e9
flops = luflop(n, n)
gflops = flops / min_time_sec / 1e9
end
end
end

catch e
success = false
error_msg = string(e)
gflops = 0.0
# Don't warn for each failure, just record it
end

Expand Down
Loading