 
 from collections import defaultdict
 
-from ...data.raw_trace import GPUStreamID
+from ...data.raw_trace import GPUStreamID, GPUStageID
 from ...data.processed_trace import GPUWorkload
 from ...drawable.text_pane_widget import TextPaneWidget
 
@@ -251,13 +251,71 @@ def get_active_region_stats(self):
 
         return self.cached_range_info
 
+    def compute_active_event_stats_single(self, event):
+        '''
+        Get the metrics for the active event.
+
+        This function uses a cached lookup to avoid re-calculating every
+        redraw, as the stats computation can be quite slow.
+
+        Args:
+            event: The active event.
+
+        Returns:
+            List of lines to be printed.
+        '''
+        # Skip non-workload types
+        if not isinstance(event, GPUWorkload):
+            return []
+
+        stream_name = GPUStreamID.get_ui_name(event.stream)
+        stage_name = GPUStageID.get_ui_name(event.stage)
+
+        metrics = ['']
+
+        # Report total runtime of the selected workloads
+        other_names = [
+            'API workloads:',
+            'Hardware workloads:'
+        ]
+
+        metrics.append('Active workload runtime:')
+        label_len = len(stream_name) + len(' stream:')
+        label_len = max(max(len(x) for x in other_names), label_len)
+
+        label = other_names[0]
+        metrics.append(f' {label:{label_len}} {1:>5}')
+
+        label = other_names[1]
+        metrics.append(f' {label:{label_len}} {1:>5}')
+
+        label = f'{stream_name} stream:'
+        duration = float(event.duration) / 1000000.0
+        metrics.append(f' {label:{label_len}} {duration:>5.2f} ms')
+
+        # Report total N workloads
+        metrics.append('')
+        metrics.append('Workload properties:')
+
+        label = event.get_workload_name()
+        metrics.append(f' Name: {label}')
+        metrics.append(f' Stream: {stream_name}')
+        metrics.append(f' Stage: {stage_name}')
+        metrics.append(f' Start: {event.start_time / 1000000.0:0.2f} ms')
+        metrics.append(f' Duration: {event.duration / 1000000.0:0.2f} ms')
+        metrics.append('')
+        return metrics
+
     def compute_active_event_stats_multi(self, active):
         '''
-        Get the metrics for the active time range.
+        Get the metrics for the active events.
 
         This function uses a cached lookup to avoid re-calculating every
         redraw, as the stats computation can be quite slow.
 
+        Args:
+            active: List of active events.
+
         Returns:
             List of lines to be printed.
         '''
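
The label column in the new single-event report is aligned with a nested format specification: in {label:{label_len}}, the field width itself is an f-string substitution computed from the longest label. A minimal standalone sketch of that technique, with invented labels and values rather than the real workload data:

    # Pad every label to the width of the longest one so the values line
    # up in a single column; {value:>5.2f} right-aligns each number.
    labels = ['API workloads:', 'Hardware workloads:', 'Compute stream:']
    label_len = max(len(x) for x in labels)

    for label, value in zip(labels, [1.0, 2.0, 3.25]):
        print(f' {label:{label_len}} {value:>5.2f} ms')
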
@@ -368,6 +426,9 @@ def get_active_event_stats(self):
         if len(active) == 0:
             info = None
 
+        elif len(active) == 1:
+            info = self.compute_active_event_stats_single(active[0])
+
         else:
             info = self.compute_active_event_stats_multi(active)
 
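
In compute_active_event_stats_single above, start and duration are divided by 1000000.0 before being formatted with 0.2f, which suggests the event timestamps are stored in nanoseconds and reported in milliseconds. A small sketch of that conversion, with invented values standing in for event.start_time and event.duration:

    # Invented nanosecond values, used only to exercise the formatting.
    start_time = 1_500_000
    duration = 2_750_000

    print(f' Start: {start_time / 1000000.0:0.2f} ms')      # prints ' Start: 1.50 ms'
    print(f' Duration: {duration / 1000000.0:0.2f} ms')     # prints ' Duration: 2.75 ms'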