|
73 | 73 |
|
74 | 74 | """
|
75 | 75 | benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
|
76 |
| - samples=5, seconds=0.5, sizes=[:small, :medium]) |
| 76 | + samples=5, seconds=0.5, sizes=[:small, :medium], |
| 77 | + maxtime=100.0) |
77 | 78 |
|
78 | 79 | Benchmark the given algorithms across different matrix sizes and element types.
|
79 | 80 | Returns a DataFrame with results including element type information.
|
| 81 | +
|
| 82 | +# Arguments |
| 83 | +- `maxtime::Float64 = 100.0`: Time budget in seconds for each algorithm test. The warmup solve is given at most `maxtime` seconds. |
| 84 | + A run is recorded as NaN if the warmup times out, or if less than twice the warmup time remains for benchmarking. |
80 | 85 | """
|
81 | 86 | function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
|
82 | 87 | samples = 5, seconds = 0.5, sizes = [:tiny, :small, :medium, :large],
|
83 |
| - check_correctness = true, correctness_tol = 1e0) |
| 88 | + check_correctness = true, correctness_tol = 1e0, maxtime = 100.0) |
84 | 89 |
|
85 | 90 | # Set benchmark parameters
|
86 | 91 | old_params = BenchmarkTools.DEFAULT_PARAMETERS
|
@@ -136,52 +141,120 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
|
136 | 141 | ProgressMeter.update!(progress,
|
137 | 142 | desc="Benchmarking $name on $(n)×$(n) $eltype matrix: ")
|
138 | 143 |
|
139 |
| - gflops = 0.0 |
| 144 | + gflops = NaN # Use NaN for timed out runs |
140 | 145 | success = true
|
141 | 146 | error_msg = ""
|
142 | 147 | passed_correctness = true
|
| 148 | + timed_out = false |
143 | 149 |
|
144 | 150 | try
|
145 | 151 | # Create the linear problem for this test
|
146 | 152 | prob = LinearProblem(copy(A), copy(b);
|
147 | 153 | u0 = copy(u0),
|
148 | 154 | alias = LinearAliasSpecifier(alias_A = true, alias_b = true))
|
149 | 155 |
|
150 |
| - # Warmup run and correctness check |
151 |
| - warmup_sol = solve(prob, alg) |
| 156 | + # Time the warmup run and correctness check |
| 157 | + start_time = time() |
152 | 158 |
|
153 |
| - # Check correctness if reference solution is available |
154 |
| - if check_correctness && reference_solution !== nothing |
155 |
| - # Compute relative error |
156 |
| - rel_error = norm(warmup_sol.u - reference_solution.u) / norm(reference_solution.u) |
157 |
| - |
158 |
| - if rel_error > correctness_tol |
159 |
| - passed_correctness = false |
160 |
| - @warn "Algorithm $name failed correctness check for size $n, eltype $eltype. " * |
161 |
| - "Relative error: $(round(rel_error, sigdigits=3)) > tolerance: $correctness_tol. " * |
162 |
| - "Algorithm will be excluded from results." |
163 |
| - success = false |
164 |
| - error_msg = "Failed correctness check (rel_error = $(round(rel_error, sigdigits=3)))" |
| 159 | + # Create a channel for communication between tasks |
| 160 | + result_channel = Channel(1) |
| 161 | + |
| 162 | + # Warmup run and correctness check with timeout |
| 163 | + warmup_task = @async begin |
| 164 | + try |
| 165 | + result = solve(prob, alg) |
| 166 | + put!(result_channel, result) |
| 167 | + catch e |
| 168 | + put!(result_channel, e) |
| 169 | + end |
| 170 | + end |
| 171 | + |
| 172 | + # Timer task to enforce timeout |
| 173 | + timer_task = @async begin |
| 174 | + sleep(maxtime) |
| 175 | + if !istaskdone(warmup_task) |
| 176 | + try |
| 177 | + Base.throwto(warmup_task, InterruptException()) |
| 178 | + catch |
| 179 | + # Task might be in non-interruptible state |
| 180 | + end |
| 181 | + put!(result_channel, :timeout) |
| 182 | + end |
| 183 | + end |
| 184 | + |
| 185 | + # Wait for result or timeout |
| 186 | + warmup_sol = nothing |
| 187 | + result = take!(result_channel) |
| 188 | + |
| 189 | + # Clean up timer task if still running |
| 190 | + if !istaskdone(timer_task) |
| 191 | + try |
| 192 | + Base.throwto(timer_task, InterruptException()) |
| 193 | + catch |
| 194 | + # Timer task might have already finished |
165 | 195 | end
|
166 | 196 | end
|
167 | 197 |
|
168 |
| - # Only benchmark if correctness check passed |
169 |
| - if passed_correctness |
170 |
| - # Actual benchmark |
171 |
| - bench = @benchmark solve($prob, $alg) setup=(prob = LinearProblem( |
172 |
| - copy($A), copy($b); |
173 |
| - u0 = copy($u0), |
174 |
| - alias = LinearAliasSpecifier(alias_A = true, alias_b = true))) |
175 |
| - |
176 |
| - # Calculate GFLOPs |
177 |
| - min_time_sec = minimum(bench.times) / 1e9 |
178 |
| - flops = luflop(n, n) |
179 |
| - gflops = flops / min_time_sec / 1e9 |
| 198 | + if result === :timeout |
| 199 | + # Task timed out |
| 200 | + timed_out = true |
| 201 | + @warn "Algorithm $name timed out (exceeded $(maxtime)s) for size $n, eltype $eltype. Recording as NaN." |
| 202 | + success = false |
| 203 | + error_msg = "Timed out (exceeded $(maxtime)s)" |
| 204 | + gflops = NaN |
| 205 | + elseif result isa Exception |
| 206 | + # Task threw an error |
| 207 | + throw(result) |
| 208 | + else |
| 209 | + # Successful completion |
| 210 | + warmup_sol = result |
| 211 | + elapsed_time = time() - start_time |
| 212 | + |
| 213 | + # Check correctness if reference solution is available |
| 214 | + if check_correctness && reference_solution !== nothing |
| 215 | + # Compute relative error |
| 216 | + rel_error = norm(warmup_sol.u - reference_solution.u) / norm(reference_solution.u) |
| 217 | + |
| 218 | + if rel_error > correctness_tol |
| 219 | + passed_correctness = false |
| 220 | + @warn "Algorithm $name failed correctness check for size $n, eltype $eltype. " * |
| 221 | + "Relative error: $(round(rel_error, sigdigits=3)) > tolerance: $correctness_tol. " * |
| 222 | + "Algorithm will be excluded from results." |
| 223 | + success = false |
| 224 | + error_msg = "Failed correctness check (rel_error = $(round(rel_error, sigdigits=3)))" |
| 225 | + gflops = 0.0 |
| 226 | + end |
| 227 | + end |
| 228 | + |
| 229 | + # Only benchmark if correctness check passed and we have time remaining |
| 230 | + if passed_correctness && !timed_out |
| 231 | + # Check if we have enough time remaining for benchmarking |
| 232 | + # Allow at least 2x the warmup time for benchmarking |
| 233 | + remaining_time = maxtime - elapsed_time |
| 234 | + if remaining_time < 2 * elapsed_time |
| 235 | + @warn "Algorithm $name: insufficient time remaining for benchmarking (warmup took $(round(elapsed_time, digits=2))s). Recording as NaN." |
| 236 | + gflops = NaN |
| 237 | + success = false |
| 238 | + error_msg = "Insufficient time for benchmarking" |
| 239 | + else |
| 240 | + # Actual benchmark |
| 241 | + bench = @benchmark solve($prob, $alg) setup=(prob = LinearProblem( |
| 242 | + copy($A), copy($b); |
| 243 | + u0 = copy($u0), |
| 244 | + alias = LinearAliasSpecifier(alias_A = true, alias_b = true))) |
| 245 | + |
| 246 | + # Calculate GFLOPs |
| 247 | + min_time_sec = minimum(bench.times) / 1e9 |
| 248 | + flops = luflop(n, n) |
| 249 | + gflops = flops / min_time_sec / 1e9 |
| 250 | + end |
| 251 | + end |
180 | 252 | end
|
181 | 253 |
|
182 | 254 | catch e
|
183 | 255 | success = false
|
184 | 256 | error_msg = string(e)
|
| 257 | + gflops = NaN |
185 | 258 | # Don't warn for each failure, just record it
|
186 | 259 | end
|
187 | 260 |
|
@@ -252,8 +325,8 @@ Categorize the benchmark results into size ranges and find the best algorithm fo
|
252 | 325 | For complex types, avoids RFLUFactorization if possible due to known issues.
|
253 | 326 | """
|
254 | 327 | function categorize_results(df::DataFrame)
|
255 |
| - # Filter successful results |
256 |
| - successful_df = filter(row -> row.success, df) |
| 328 | + # Filter successful results and exclude NaN values |
| 329 | + successful_df = filter(row -> row.success && !isnan(row.gflops), df) |
257 | 330 |
|
258 | 331 | if nrow(successful_df) == 0
|
259 | 332 | @warn "No successful benchmark results found!"
|
@@ -293,8 +366,9 @@ function categorize_results(df::DataFrame)
|
293 | 366 | continue
|
294 | 367 | end
|
295 | 368 |
|
296 |
| - # Calculate average GFLOPs for each algorithm in this range |
297 |
| - avg_results = combine(groupby(range_df, :algorithm), :gflops => mean => :avg_gflops) |
| 369 | + # Calculate average GFLOPs for each algorithm in this range, excluding NaN values |
| 370 | + avg_results = combine(groupby(range_df, :algorithm), |
| 371 | + :gflops => (x -> mean(filter(!isnan, x))) => :avg_gflops) |
298 | 372 |
|
299 | 373 | # Sort by performance
|
300 | 374 | sort!(avg_results, :avg_gflops, rev=true)
|
|
0 commit comments