 import argparse
 import shlex
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional, Any, Union
+from dataclasses import dataclass

 from .parser import get_parser

 from .triton_op import BenchmarkOperatorResult, REGISTERED_X_VALS


+@dataclass
+class MetricComparison:
+    """Data class to store metric comparison results."""
+    val_a: float
+    val_b: float
+    improvement_pct: float
+    x_val: Any
+    backend: str
+    metric: str
+
+
+def _extract_metric_value(metric_obj: Any) -> Optional[float]:
+    """Extract and normalize metric value from metric objects.
+
+    Only handles:
+    - Objects with p50 attribute (percentile-based metrics)
+    - Direct numeric values (int/float)
+
+    Skips all other types (tuples, complex objects, etc.)
+    """
+    if metric_obj is None:
+        return None
+
+    # Handle objects with p50 attribute (percentile-based metrics)
+    if hasattr(metric_obj, "p50"):
+        return float(metric_obj.p50)
+
+    # Handle direct numeric values only
+    if isinstance(metric_obj, (int, float)):
+        return float(metric_obj)
+
+    # Skip all other types (tuples, complex objects, etc.)
+    return None
+
+
+def _calculate_improvement(val_a: float, val_b: float) -> float:
+    """Calculate percentage improvement from val_a to val_b."""
+    if val_a == 0:
+        return 0.0
+    return ((val_b - val_a) / val_a) * 100
+
+
+def _get_comparable_data_points(
+    result_a: BenchmarkOperatorResult,
+    result_b: BenchmarkOperatorResult,
+    common_x_vals: List,
+    common_backends: List[str],
+    metric: str,
+) -> List[MetricComparison]:
+    """Get all comparable data points for a specific metric."""
+    # Create result dictionaries for easier lookup
+    result_dict_a = {x_val: metrics_dict for x_val, metrics_dict in result_a.result}
+    result_dict_b = {x_val: metrics_dict for x_val, metrics_dict in result_b.result}
+
+    comparisons = []
+
+    for backend in common_backends:
+        for x_val in common_x_vals:
+            if backend in result_dict_a[x_val] and backend in result_dict_b[x_val]:
+                metrics_a = result_dict_a[x_val][backend]
+                metrics_b = result_dict_b[x_val][backend]
+
+                # Try to get the metric from direct attribute first
+                raw_val_a = getattr(metrics_a, metric, None)
+                raw_val_b = getattr(metrics_b, metric, None)
+
+                # If not found, check in extra_metrics
+                if raw_val_a is None and hasattr(metrics_a, "extra_metrics") and metrics_a.extra_metrics:
+                    raw_val_a = metrics_a.extra_metrics.get(metric, None)
+                if raw_val_b is None and hasattr(metrics_b, "extra_metrics") and metrics_b.extra_metrics:
+                    raw_val_b = metrics_b.extra_metrics.get(metric, None)
+
+                val_a = _extract_metric_value(raw_val_a)
+                val_b = _extract_metric_value(raw_val_b)
+
+                if val_a is not None and val_b is not None:
+                    improvement_pct = _calculate_improvement(val_a, val_b)
+                    comparisons.append(MetricComparison(
+                        val_a=val_a,
+                        val_b=val_b,
+                        improvement_pct=improvement_pct,
+                        x_val=x_val,
+                        backend=backend,
+                        metric=metric
+                    ))
+
+    return comparisons
+
+
 def parse_ab_config(config_str: str) -> List[str]:
     """Parse A/B configuration string into argument list."""
     if not config_str:
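Note on the two helpers added above: they reduce every metric value to a plain float before any comparison is made. The snippet below is a minimal illustrative sketch, not part of the patch; FakeLatency is a hypothetical stand-in that only mimics a percentile-style metric object exposing a p50 attribute.

    from dataclasses import dataclass

    @dataclass
    class FakeLatency:  # hypothetical stand-in for a percentile-based metric object
        p50: float

    assert _extract_metric_value(FakeLatency(p50=1.25)) == 1.25  # p50 is used
    assert _extract_metric_value(0.8) == 0.8                     # plain numbers pass through
    assert _extract_metric_value((1.0, 2.0)) is None             # tuples are skipped
    assert _calculate_improvement(2.0, 1.5) == -25.0             # B is 25% lower than A
    assert _calculate_improvement(0.0, 1.0) == 0.0               # division-by-zero guard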
@@ -163,48 +253,37 @@ def parse_config_to_dict(args):
     return differences


-def _calculate_performance_summary(
+def _get_all_comparable_data_points(
     result_a: BenchmarkOperatorResult,
     result_b: BenchmarkOperatorResult,
     common_x_vals: List,
     common_backends: List[str],
+) -> Dict[str, List[MetricComparison]]:
+    """Get all comparable data points for all metrics at once."""
+    all_comparisons = {}
+
+    for metric in result_a.metrics:
+        all_comparisons[metric] = _get_comparable_data_points(
+            result_a, result_b, common_x_vals, common_backends, metric
+        )
+
+    return all_comparisons
+
+
+def _calculate_performance_summary(
+    all_comparisons: Dict[str, List[MetricComparison]],
+    common_backends: List[str],
 ) -> Dict[str, Dict[str, float]]:
-    """Calculate performance summary statistics."""
+    """Calculate performance summary statistics from pre-computed comparisons."""
     summary = {}

-    # Create result dictionaries for easier lookup
-    result_dict_a = {x_val: metrics_dict for x_val, metrics_dict in result_a.result}
-    result_dict_b = {x_val: metrics_dict for x_val, metrics_dict in result_b.result}
-
     for backend in common_backends:
         backend_summary = {}

-        for metric in result_a.metrics:
-            improvements = []
-
-            for x_val in common_x_vals:
-                if backend in result_dict_a[x_val] and backend in result_dict_b[x_val]:
-                    metrics_a = result_dict_a[x_val][backend]
-                    metrics_b = result_dict_b[x_val][backend]
-
-                    val_a = getattr(metrics_a, metric, None)
-                    val_b = getattr(metrics_b, metric, None)
-
-                    if val_a is not None and val_b is not None:
-                        # Handle different metric types
-                        if hasattr(val_a, "p50"):
-                            val_a_num = val_a.p50
-                        else:
-                            val_a_num = val_a
-
-                        if hasattr(val_b, "p50"):
-                            val_b_num = val_b.p50
-                        else:
-                            val_b_num = val_b
-
-                        if val_a_num != 0:
-                            improvement = ((val_b_num - val_a_num) / val_a_num) * 100
-                            improvements.append(improvement)
+        for metric, comparisons in all_comparisons.items():
+            # Filter for current backend
+            backend_comparisons = [c for c in comparisons if c.backend == backend]
+            improvements = [c.improvement_pct for c in backend_comparisons]

             if improvements:
                 backend_summary[metric] = {
@@ -259,6 +338,14 @@ def compare_ab_results(
         print("ERROR: No common backends found between configurations")
         return

+    # ============================================================================
+    # PRE-COMPUTE: Get all comparable data points once
+    # ============================================================================
+    all_comparisons = _get_all_comparable_data_points(
+        result_a, result_b, common_x_vals, common_backends
+    )
+
+
     # ============================================================================
     # SECTION 1: Configuration Analysis
     # ============================================================================
@@ -290,9 +377,7 @@ def compare_ab_results(
     print("Performance Summary")
     print("-" * 70)

-    summary = _calculate_performance_summary(
-        result_a, result_b, common_x_vals, common_backends
-    )
+    summary = _calculate_performance_summary(all_comparisons, common_backends)

     for backend in common_backends:
         print(f"\n{backend}:")
@@ -320,8 +405,10 @@ def compare_ab_results(

     x_val_name = REGISTERED_X_VALS.get(result_a.op_name, "x_val")

-    # Show all metrics for detailed comparison
+    # Show only metrics that have comparable data
     for metric in result_a.metrics:
+        if metric not in all_comparisons or len(all_comparisons[metric]) == 0:
+            continue  # Skip metrics with no comparable data
         print(f"\nMetric: {metric}")
         print("Backend".ljust(15), end="")
         print(x_val_name.ljust(20), end="")
@@ -330,49 +417,36 @@ def compare_ab_results(
         print("Difference".ljust(12))
         print("-" * 71)

-        for backend in common_backends:
-            first_row = True
-            for x_val in common_x_vals:
-                if (
-                    backend not in result_dict_a[x_val]
-                    or backend not in result_dict_b[x_val]
-                ):
-                    continue
+        # Use pre-computed comparisons
+        comparisons = all_comparisons.get(metric, [])

-                metrics_a = result_dict_a[x_val][backend]
-                metrics_b = result_dict_b[x_val][backend]
-
-                val_a = getattr(metrics_a, metric, None)
-                val_b = getattr(metrics_b, metric, None)
-
-                if val_a is not None and val_b is not None:
-                    # Handle different data types
-                    if hasattr(val_a, "p50"):
-                        val_a_num = val_a.p50
-                        val_b_num = val_b.p50
-                    else:
-                        val_a_num = val_a
-                        val_b_num = val_b
+        # Group by backend for display
+        backend_comparisons = {}
+        for comp in comparisons:
+            if comp.backend not in backend_comparisons:
+                backend_comparisons[comp.backend] = []
+            backend_comparisons[comp.backend].append(comp)

-                    if val_a_num != 0:
-                        diff_pct = ((val_b_num - val_a_num) / val_a_num) * 100
-                    else:
-                        diff_pct = 0
-
-                    # Format values
-                    if isinstance(val_a_num, float):
-                        val_a_str = f"{val_a_num:.3f}"
-                        val_b_str = f"{val_b_num:.3f}"
-                    else:
-                        val_a_str = str(val_a_num)
-                        val_b_str = str(val_b_num)
-
-                    # Print row
-                    backend_name = backend if first_row else ""
-                    print(
-                        f"{backend_name:<15} {str(x_val):<20} {val_a_str:<12} {val_b_str:<12} {diff_pct:+5.1f}%"
-                    )
-                    first_row = False
+        for backend in common_backends:
+            if backend not in backend_comparisons:
+                continue
+
+            first_row = True
+            for comp in backend_comparisons[backend]:
+                # Format values
+                if isinstance(comp.val_a, float):
+                    val_a_str = f"{comp.val_a:.3f}"
+                    val_b_str = f"{comp.val_b:.3f}"
+                else:
+                    val_a_str = str(comp.val_a)
+                    val_b_str = str(comp.val_b)
+
+                # Print row
+                backend_name = backend if first_row else ""
+                print(
+                    f"{backend_name:<15} {str(comp.x_val):<20} {val_a_str:<12} {val_b_str:<12} {comp.improvement_pct:+5.1f}%"
+                )
+                first_row = False

         if not first_row:  # Only print separator if we printed data
             print()
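Taken together, the patch computes the comparable data points once and reuses them for both the performance summary and the detailed per-metric table. A rough usage sketch, not part of the patch, assuming two already-loaded BenchmarkOperatorResult objects plus the common_x_vals/common_backends lists that compare_ab_results derives; variable names are illustrative.

    # Pre-compute every (backend, x_val) pair that has a usable value for each metric.
    all_comparisons = _get_all_comparable_data_points(
        result_a, result_b, common_x_vals, common_backends
    )

    # The same dictionary backs the summary statistics...
    summary = _calculate_performance_summary(all_comparisons, common_backends)

    # ...and the per-metric detail rows, so the two sections cannot disagree
    # about which data points were comparable.
    for metric, comparisons in all_comparisons.items():
        for comp in comparisons:
            print(
                f"{comp.backend}: {metric} @ {comp.x_val}: "
                f"{comp.val_a:.3f} -> {comp.val_b:.3f} ({comp.improvement_pct:+.1f}%)"
            )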