2
2
3
3
import argparse
4
4
import shlex
5
- from typing import Dict , List , Tuple , Optional , Any , Union
6
- from dataclasses import dataclass
5
+ from typing import Dict , List , Tuple
7
6
8
7
from .parser import get_parser
9
8
10
9
from .triton_op import BenchmarkOperatorResult , REGISTERED_X_VALS
11
10
12
11
13
- @dataclass
14
- class MetricComparison :
15
- """Data class to store metric comparison results."""
16
- val_a : float
17
- val_b : float
18
- improvement_pct : float
19
- x_val : Any
20
- backend : str
21
- metric : str
22
-
23
-
24
- def _extract_metric_value (metric_obj : Any ) -> Optional [float ]:
25
- """Extract and normalize metric value from metric objects.
26
-
27
- Only handles:
28
- - Objects with p50 attribute (percentile-based metrics)
29
- - Direct numeric values (int/float)
30
-
31
- Skips all other types (tuples, complex objects, etc.)
32
- """
33
- if metric_obj is None :
34
- return None
35
-
36
- # Handle objects with p50 attribute (percentile-based metrics)
37
- if hasattr (metric_obj , "p50" ):
38
- return float (metric_obj .p50 )
39
-
40
- # Handle direct numeric values only
41
- if isinstance (metric_obj , (int , float )):
42
- return float (metric_obj )
43
-
44
- # Skip all other types (tuples, complex objects, etc.)
45
- return None
46
-
47
-
48
- def _calculate_improvement (val_a : float , val_b : float ) -> float :
49
- """Calculate percentage improvement from val_a to val_b."""
50
- if val_a == 0 :
51
- return 0.0
52
- return ((val_b - val_a ) / val_a ) * 100
53
-
54
-
55
- def _get_comparable_data_points (
56
- result_a : BenchmarkOperatorResult ,
57
- result_b : BenchmarkOperatorResult ,
58
- common_x_vals : List ,
59
- common_backends : List [str ],
60
- metric : str ,
61
- ) -> List [MetricComparison ]:
62
- """Get all comparable data points for a specific metric."""
63
- # Create result dictionaries for easier lookup
64
- result_dict_a = {x_val : metrics_dict for x_val , metrics_dict in result_a .result }
65
- result_dict_b = {x_val : metrics_dict for x_val , metrics_dict in result_b .result }
66
-
67
- comparisons = []
68
-
69
- for backend in common_backends :
70
- for x_val in common_x_vals :
71
- if backend in result_dict_a [x_val ] and backend in result_dict_b [x_val ]:
72
- metrics_a = result_dict_a [x_val ][backend ]
73
- metrics_b = result_dict_b [x_val ][backend ]
74
-
75
- # Try to get the metric from direct attribute first
76
- raw_val_a = getattr (metrics_a , metric , None )
77
- raw_val_b = getattr (metrics_b , metric , None )
78
-
79
- # If not found, check in extra_metrics
80
- if raw_val_a is None and hasattr (metrics_a , 'extra_metrics' ) and metrics_a .extra_metrics :
81
- raw_val_a = metrics_a .extra_metrics .get (metric , None )
82
- if raw_val_b is None and hasattr (metrics_b , 'extra_metrics' ) and metrics_b .extra_metrics :
83
- raw_val_b = metrics_b .extra_metrics .get (metric , None )
84
-
85
- val_a = _extract_metric_value (raw_val_a )
86
- val_b = _extract_metric_value (raw_val_b )
87
-
88
- if val_a is not None and val_b is not None :
89
- improvement_pct = _calculate_improvement (val_a , val_b )
90
- comparisons .append (MetricComparison (
91
- val_a = val_a ,
92
- val_b = val_b ,
93
- improvement_pct = improvement_pct ,
94
- x_val = x_val ,
95
- backend = backend ,
96
- metric = metric
97
- ))
98
-
99
- return comparisons
100
-
101
-
102
12
def parse_ab_config (config_str : str ) -> List [str ]:
103
13
"""Parse A/B configuration string into argument list."""
104
14
if not config_str :
@@ -253,37 +163,48 @@ def parse_config_to_dict(args):
253
163
return differences
254
164
255
165
256
- def _get_all_comparable_data_points (
166
+ def _calculate_performance_summary (
257
167
result_a : BenchmarkOperatorResult ,
258
168
result_b : BenchmarkOperatorResult ,
259
169
common_x_vals : List ,
260
170
common_backends : List [str ],
261
- ) -> Dict [str , List [MetricComparison ]]:
262
- """Get all comparable data points for all metrics at once."""
263
- all_comparisons = {}
264
-
265
- for metric in result_a .metrics :
266
- all_comparisons [metric ] = _get_comparable_data_points (
267
- result_a , result_b , common_x_vals , common_backends , metric
268
- )
269
-
270
- return all_comparisons
271
-
272
-
273
- def _calculate_performance_summary (
274
- all_comparisons : Dict [str , List [MetricComparison ]],
275
- common_backends : List [str ],
276
171
) -> Dict [str , Dict [str , float ]]:
277
- """Calculate performance summary statistics from pre-computed comparisons ."""
172
+ """Calculate performance summary statistics."""
278
173
summary = {}
279
174
175
+ # Create result dictionaries for easier lookup
176
+ result_dict_a = {x_val : metrics_dict for x_val , metrics_dict in result_a .result }
177
+ result_dict_b = {x_val : metrics_dict for x_val , metrics_dict in result_b .result }
178
+
280
179
for backend in common_backends :
281
180
backend_summary = {}
282
181
283
- for metric , comparisons in all_comparisons .items ():
284
- # Filter for current backend
285
- backend_comparisons = [c for c in comparisons if c .backend == backend ]
286
- improvements = [c .improvement_pct for c in backend_comparisons ]
182
+ for metric in result_a .metrics :
183
+ improvements = []
184
+
185
+ for x_val in common_x_vals :
186
+ if backend in result_dict_a [x_val ] and backend in result_dict_b [x_val ]:
187
+ metrics_a = result_dict_a [x_val ][backend ]
188
+ metrics_b = result_dict_b [x_val ][backend ]
189
+
190
+ val_a = getattr (metrics_a , metric , None )
191
+ val_b = getattr (metrics_b , metric , None )
192
+
193
+ if val_a is not None and val_b is not None :
194
+ # Handle different metric types
195
+ if hasattr (val_a , "p50" ):
196
+ val_a_num = val_a .p50
197
+ else :
198
+ val_a_num = val_a
199
+
200
+ if hasattr (val_b , "p50" ):
201
+ val_b_num = val_b .p50
202
+ else :
203
+ val_b_num = val_b
204
+
205
+ if val_a_num != 0 :
206
+ improvement = ((val_b_num - val_a_num ) / val_a_num ) * 100
207
+ improvements .append (improvement )
287
208
288
209
if improvements :
289
210
backend_summary [metric ] = {
@@ -338,14 +259,6 @@ def compare_ab_results(
338
259
print ("ERROR: No common backends found between configurations" )
339
260
return
340
261
341
- # ============================================================================
342
- # PRE-COMPUTE: Get all comparable data points once
343
- # ============================================================================
344
- all_comparisons = _get_all_comparable_data_points (
345
- result_a , result_b , common_x_vals , common_backends
346
- )
347
-
348
-
349
262
# ============================================================================
350
263
# SECTION 1: Configuration Analysis
351
264
# ============================================================================
@@ -377,7 +290,9 @@ def compare_ab_results(
377
290
print ("Performance Summary" )
378
291
print ("-" * 70 )
379
292
380
- summary = _calculate_performance_summary (all_comparisons , common_backends )
293
+ summary = _calculate_performance_summary (
294
+ result_a , result_b , common_x_vals , common_backends
295
+ )
381
296
382
297
for backend in common_backends :
383
298
print (f"\n { backend } :" )
@@ -405,10 +320,8 @@ def compare_ab_results(
405
320
406
321
x_val_name = REGISTERED_X_VALS .get (result_a .op_name , "x_val" )
407
322
408
- # Show only metrics that have comparable data
323
+ # Show all metrics for detailed comparison
409
324
for metric in result_a .metrics :
410
- if metric not in all_comparisons or len (all_comparisons [metric ]) == 0 :
411
- continue # Skip metrics with no comparable data
412
325
print (f"\n Metric: { metric } " )
413
326
print ("Backend" .ljust (15 ), end = "" )
414
327
print (x_val_name .ljust (20 ), end = "" )
@@ -417,36 +330,49 @@ def compare_ab_results(
417
330
print ("Difference" .ljust (12 ))
418
331
print ("-" * 71 )
419
332
420
- # Use pre-computed comparisons
421
- comparisons = all_comparisons .get (metric , [])
422
-
423
- # Group by backend for display
424
- backend_comparisons = {}
425
- for comp in comparisons :
426
- if comp .backend not in backend_comparisons :
427
- backend_comparisons [comp .backend ] = []
428
- backend_comparisons [comp .backend ].append (comp )
429
-
430
333
for backend in common_backends :
431
- if backend not in backend_comparisons :
432
- continue
433
-
434
334
first_row = True
435
- for comp in backend_comparisons [backend ]:
436
- # Format values
437
- if isinstance (comp .val_a , float ):
438
- val_a_str = f"{ comp .val_a :.3f} "
439
- val_b_str = f"{ comp .val_b :.3f} "
440
- else :
441
- val_a_str = str (comp .val_a )
442
- val_b_str = str (comp .val_b )
443
-
444
- # Print row
445
- backend_name = backend if first_row else ""
446
- print (
447
- f"{ backend_name :<15} { str (comp .x_val ):<20} { val_a_str :<12} { val_b_str :<12} { comp .improvement_pct :+5.1f} %"
448
- )
449
- first_row = False
335
+ for x_val in common_x_vals :
336
+ if (
337
+ backend not in result_dict_a [x_val ]
338
+ or backend not in result_dict_b [x_val ]
339
+ ):
340
+ continue
341
+
342
+ metrics_a = result_dict_a [x_val ][backend ]
343
+ metrics_b = result_dict_b [x_val ][backend ]
344
+
345
+ val_a = getattr (metrics_a , metric , None )
346
+ val_b = getattr (metrics_b , metric , None )
347
+
348
+ if val_a is not None and val_b is not None :
349
+ # Handle different data types
350
+ if hasattr (val_a , "p50" ):
351
+ val_a_num = val_a .p50
352
+ val_b_num = val_b .p50
353
+ else :
354
+ val_a_num = val_a
355
+ val_b_num = val_b
356
+
357
+ if val_a_num != 0 :
358
+ diff_pct = ((val_b_num - val_a_num ) / val_a_num ) * 100
359
+ else :
360
+ diff_pct = 0
361
+
362
+ # Format values
363
+ if isinstance (val_a_num , float ):
364
+ val_a_str = f"{ val_a_num :.3f} "
365
+ val_b_str = f"{ val_b_num :.3f} "
366
+ else :
367
+ val_a_str = str (val_a_num )
368
+ val_b_str = str (val_b_num )
369
+
370
+ # Print row
371
+ backend_name = backend if first_row else ""
372
+ print (
373
+ f"{ backend_name :<15} { str (x_val ):<20} { val_a_str :<12} { val_b_str :<12} { diff_pct :+5.1f} %"
374
+ )
375
+ first_row = False
450
376
451
377
if not first_row : # Only print separator if we printed data
452
378
print ()
0 commit comments