|
73 | 73 |
|
74 | 74 | """
|
75 | 75 | benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
|
76 |
| - samples=5, seconds=0.5, sizes=[:small, :medium]) |
| 76 | + samples=5, seconds=0.5, sizes=[:small, :medium], |
| 77 | + maxtime=100.0) |
77 | 78 |
|
78 | 79 | Benchmark the given algorithms across different matrix sizes and element types.
|
79 | 80 | Returns a DataFrame with results including element type information.
|
| 81 | +
|
| 82 | +# Arguments |
| 83 | +- `maxtime::Float64 = 100.0`: Maximum time in seconds for each algorithm test (including accuracy check). |
| 84 | + If the accuracy check exceeds this time, the run is skipped and recorded as NaN. |
80 | 85 | """
|
81 | 86 | function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
|
82 | 87 | samples = 5, seconds = 0.5, sizes = [:tiny, :small, :medium, :large],
|
83 |
| - check_correctness = true, correctness_tol = 1e0) |
| 88 | + check_correctness = true, correctness_tol = 1e0, maxtime = 100.0) |
84 | 89 |
|
85 | 90 | # Set benchmark parameters
|
86 | 91 | old_params = BenchmarkTools.DEFAULT_PARAMETERS
|
@@ -136,52 +141,90 @@ function benchmark_algorithms(matrix_sizes, algorithms, alg_names, eltypes;
|
136 | 141 | ProgressMeter.update!(progress,
|
137 | 142 | desc="Benchmarking $name on $(n)×$(n) $eltype matrix: ")
|
138 | 143 |
|
139 |
| - gflops = 0.0 |
| 144 | + gflops = NaN # Use NaN for timed out runs |
140 | 145 | success = true
|
141 | 146 | error_msg = ""
|
142 | 147 | passed_correctness = true
|
| 148 | + timed_out = false |
143 | 149 |
|
144 | 150 | try
|
145 | 151 | # Create the linear problem for this test
|
146 | 152 | prob = LinearProblem(copy(A), copy(b);
|
147 | 153 | u0 = copy(u0),
|
148 | 154 | alias = LinearAliasSpecifier(alias_A = true, alias_b = true))
|
149 | 155 |
|
150 |
| - # Warmup run and correctness check |
151 |
| - warmup_sol = solve(prob, alg) |
| 156 | + # Time the warmup run and correctness check |
| 157 | + start_time = time() |
152 | 158 |
|
153 |
| - # Check correctness if reference solution is available |
154 |
| - if check_correctness && reference_solution !== nothing |
155 |
| - # Compute relative error |
156 |
| - rel_error = norm(warmup_sol.u - reference_solution.u) / norm(reference_solution.u) |
157 |
| - |
158 |
| - if rel_error > correctness_tol |
159 |
| - passed_correctness = false |
160 |
| - @warn "Algorithm $name failed correctness check for size $n, eltype $eltype. " * |
161 |
| - "Relative error: $(round(rel_error, sigdigits=3)) > tolerance: $correctness_tol. " * |
162 |
| - "Algorithm will be excluded from results." |
163 |
| - success = false |
164 |
| - error_msg = "Failed correctness check (rel_error = $(round(rel_error, sigdigits=3)))" |
165 |
| - end |
| 159 | + # Warmup run and correctness check with timeout |
| 160 | + warmup_task = @async begin |
| 161 | + solve(prob, alg) |
| 162 | + end |
| 163 | + |
| 164 | + # Wait for warmup to complete or timeout |
| 165 | + warmup_sol = nothing |
| 166 | + timeout_wait = maxtime |
| 167 | + while !istaskdone(warmup_task) && (time() - start_time) < timeout_wait |
| 168 | + sleep(0.1) |
166 | 169 | end
|
167 | 170 |
|
168 |
| - # Only benchmark if correctness check passed |
169 |
| - if passed_correctness |
170 |
| - # Actual benchmark |
171 |
| - bench = @benchmark solve($prob, $alg) setup=(prob = LinearProblem( |
172 |
| - copy($A), copy($b); |
173 |
| - u0 = copy($u0), |
174 |
| - alias = LinearAliasSpecifier(alias_A = true, alias_b = true))) |
175 |
| - |
176 |
| - # Calculate GFLOPs |
177 |
| - min_time_sec = minimum(bench.times) / 1e9 |
178 |
| - flops = luflop(n, n) |
179 |
| - gflops = flops / min_time_sec / 1e9 |
| 171 | + if !istaskdone(warmup_task) |
| 172 | + # Task timed out |
| 173 | + timed_out = true |
| 174 | + @warn "Algorithm $name timed out (exceeded $(maxtime)s) for size $n, eltype $eltype. Recording as NaN." |
| 175 | + success = false |
| 176 | + error_msg = "Timed out (exceeded $(maxtime)s)" |
| 177 | + gflops = NaN |
| 178 | + else |
| 179 | + # Get the result |
| 180 | + warmup_sol = fetch(warmup_task) |
| 181 | + elapsed_time = time() - start_time |
| 182 | + |
| 183 | + # Check correctness if reference solution is available |
| 184 | + if check_correctness && reference_solution !== nothing |
| 185 | + # Compute relative error |
| 186 | + rel_error = norm(warmup_sol.u - reference_solution.u) / norm(reference_solution.u) |
| 187 | + |
| 188 | + if rel_error > correctness_tol |
| 189 | + passed_correctness = false |
| 190 | + @warn "Algorithm $name failed correctness check for size $n, eltype $eltype. " * |
| 191 | + "Relative error: $(round(rel_error, sigdigits=3)) > tolerance: $correctness_tol. " * |
| 192 | + "Algorithm will be excluded from results." |
| 193 | + success = false |
| 194 | + error_msg = "Failed correctness check (rel_error = $(round(rel_error, sigdigits=3)))" |
| 195 | + gflops = 0.0 |
| 196 | + end |
| 197 | + end |
| 198 | + |
| 199 | + # Only benchmark if correctness check passed and we have time remaining |
| 200 | + if passed_correctness && !timed_out |
| 201 | + # Check if we have enough time remaining for benchmarking |
| 202 | + # Allow at least 2x the warmup time for benchmarking |
| 203 | + remaining_time = maxtime - elapsed_time |
| 204 | + if remaining_time < 2 * elapsed_time |
| 205 | + @warn "Algorithm $name: insufficient time remaining for benchmarking (warmup took $(round(elapsed_time, digits=2))s). Recording as NaN." |
| 206 | + gflops = NaN |
| 207 | + success = false |
| 208 | + error_msg = "Insufficient time for benchmarking" |
| 209 | + else |
| 210 | + # Actual benchmark |
| 211 | + bench = @benchmark solve($prob, $alg) setup=(prob = LinearProblem( |
| 212 | + copy($A), copy($b); |
| 213 | + u0 = copy($u0), |
| 214 | + alias = LinearAliasSpecifier(alias_A = true, alias_b = true))) |
| 215 | + |
| 216 | + # Calculate GFLOPs |
| 217 | + min_time_sec = minimum(bench.times) / 1e9 |
| 218 | + flops = luflop(n, n) |
| 219 | + gflops = flops / min_time_sec / 1e9 |
| 220 | + end |
| 221 | + end |
180 | 222 | end
|
181 | 223 |
|
182 | 224 | catch e
|
183 | 225 | success = false
|
184 | 226 | error_msg = string(e)
|
| 227 | + gflops = 0.0 |
185 | 228 | # Don't warn for each failure, just record it
|
186 | 229 | end
|
187 | 230 |
|
|
0 commit comments