Skip to content

Commit 35cf1ba

Browse files
committed
Fix A/B testing for metrics with complex types
A/B test was showing empty results for some metrics like gbps that return tuples instead of simple values. Now we only compare compatible metric types (floats and objects with p50 attributes) and skip the rest. Also cleaned up some duplicate code while we were at it.
1 parent c378245 commit 35cf1ba

File tree

1 file changed

+152
-78
lines changed

1 file changed

+152
-78
lines changed

tritonbench/utils/ab_test.py

Lines changed: 152 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,103 @@
22

33
import argparse
44
import shlex
5-
from typing import Dict, List, Tuple
5+
from typing import Dict, List, Tuple, Optional, Any, Union
6+
from dataclasses import dataclass
67

78
from .parser import get_parser
89

910
from .triton_op import BenchmarkOperatorResult, REGISTERED_X_VALS
1011

1112

13+
@dataclass
class MetricComparison:
    """Data class to store metric comparison results."""

    # Normalized metric value measured under configuration A.
    val_a: float
    # Normalized metric value measured under configuration B.
    val_b: float
    # Percentage change from val_a to val_b (positive means B is larger).
    improvement_pct: float
    # The x-axis value (benchmark input) this comparison belongs to.
    x_val: Any
    # Name of the backend the metric was measured on.
    backend: str
    # Name of the metric being compared.
    metric: str
22+
23+
24+
def _extract_metric_value(metric_obj: Any) -> Optional[float]:
25+
"""Extract and normalize metric value from metric objects.
26+
27+
Only handles:
28+
- Objects with p50 attribute (percentile-based metrics)
29+
- Direct numeric values (int/float)
30+
31+
Skips all other types (tuples, complex objects, etc.)
32+
"""
33+
if metric_obj is None:
34+
return None
35+
36+
# Handle objects with p50 attribute (percentile-based metrics)
37+
if hasattr(metric_obj, "p50"):
38+
return float(metric_obj.p50)
39+
40+
# Handle direct numeric values only
41+
if isinstance(metric_obj, (int, float)):
42+
return float(metric_obj)
43+
44+
# Skip all other types (tuples, complex objects, etc.)
45+
return None
46+
47+
48+
def _calculate_improvement(val_a: float, val_b: float) -> float:
49+
"""Calculate percentage improvement from val_a to val_b."""
50+
if val_a == 0:
51+
return 0.0
52+
return ((val_b - val_a) / val_a) * 100
53+
54+
55+
def _get_comparable_data_points(
    result_a: BenchmarkOperatorResult,
    result_b: BenchmarkOperatorResult,
    common_x_vals: List,
    common_backends: List[str],
    metric: str,
) -> List[MetricComparison]:
    """Collect every comparable (backend, x_val) data point for one metric."""
    # Index each result by x_val for O(1) lookup.
    lookup_a = dict(result_a.result)
    lookup_b = dict(result_b.result)

    def resolve_raw(metrics_obj):
        # Prefer the direct attribute; fall back to the extra_metrics dict.
        raw = getattr(metrics_obj, metric, None)
        if raw is None and getattr(metrics_obj, 'extra_metrics', None):
            raw = metrics_obj.extra_metrics.get(metric, None)
        return raw

    comparisons: List[MetricComparison] = []
    for backend in common_backends:
        for x_val in common_x_vals:
            per_backend_a = lookup_a[x_val]
            if backend not in per_backend_a:
                continue
            per_backend_b = lookup_b[x_val]
            if backend not in per_backend_b:
                continue

            val_a = _extract_metric_value(resolve_raw(per_backend_a[backend]))
            val_b = _extract_metric_value(resolve_raw(per_backend_b[backend]))
            if val_a is None or val_b is None:
                continue  # Skip metrics that are not comparable on both sides.

            comparisons.append(MetricComparison(
                val_a=val_a,
                val_b=val_b,
                improvement_pct=_calculate_improvement(val_a, val_b),
                x_val=x_val,
                backend=backend,
                metric=metric
            ))

    return comparisons
100+
101+
12102
def parse_ab_config(config_str: str) -> List[str]:
13103
"""Parse A/B configuration string into argument list."""
14104
if not config_str:
@@ -163,48 +253,37 @@ def parse_config_to_dict(args):
163253
return differences
164254

165255

166-
def _calculate_performance_summary(
256+
def _get_all_comparable_data_points(
    result_a: BenchmarkOperatorResult,
    result_b: BenchmarkOperatorResult,
    common_x_vals: List,
    common_backends: List[str],
) -> Dict[str, List[MetricComparison]]:
    """Get all comparable data points for all metrics at once."""
    # One lookup pass per metric declared on the A-side result.
    return {
        metric: _get_comparable_data_points(
            result_a, result_b, common_x_vals, common_backends, metric
        )
        for metric in result_a.metrics
    }
271+
272+
273+
def _calculate_performance_summary(
274+
all_comparisons: Dict[str, List[MetricComparison]],
275+
common_backends: List[str],
171276
) -> Dict[str, Dict[str, float]]:
172-
"""Calculate performance summary statistics."""
277+
"""Calculate performance summary statistics from pre-computed comparisons."""
173278
summary = {}
174279

175-
# Create result dictionaries for easier lookup
176-
result_dict_a = {x_val: metrics_dict for x_val, metrics_dict in result_a.result}
177-
result_dict_b = {x_val: metrics_dict for x_val, metrics_dict in result_b.result}
178-
179280
for backend in common_backends:
180281
backend_summary = {}
181282

182-
for metric in result_a.metrics:
183-
improvements = []
184-
185-
for x_val in common_x_vals:
186-
if backend in result_dict_a[x_val] and backend in result_dict_b[x_val]:
187-
metrics_a = result_dict_a[x_val][backend]
188-
metrics_b = result_dict_b[x_val][backend]
189-
190-
val_a = getattr(metrics_a, metric, None)
191-
val_b = getattr(metrics_b, metric, None)
192-
193-
if val_a is not None and val_b is not None:
194-
# Handle different metric types
195-
if hasattr(val_a, "p50"):
196-
val_a_num = val_a.p50
197-
else:
198-
val_a_num = val_a
199-
200-
if hasattr(val_b, "p50"):
201-
val_b_num = val_b.p50
202-
else:
203-
val_b_num = val_b
204-
205-
if val_a_num != 0:
206-
improvement = ((val_b_num - val_a_num) / val_a_num) * 100
207-
improvements.append(improvement)
283+
for metric, comparisons in all_comparisons.items():
284+
# Filter for current backend
285+
backend_comparisons = [c for c in comparisons if c.backend == backend]
286+
improvements = [c.improvement_pct for c in backend_comparisons]
208287

209288
if improvements:
210289
backend_summary[metric] = {
@@ -259,6 +338,14 @@ def compare_ab_results(
259338
print("ERROR: No common backends found between configurations")
260339
return
261340

341+
# ============================================================================
342+
# PRE-COMPUTE: Get all comparable data points once
343+
# ============================================================================
344+
all_comparisons = _get_all_comparable_data_points(
345+
result_a, result_b, common_x_vals, common_backends
346+
)
347+
348+
262349
# ============================================================================
263350
# SECTION 1: Configuration Analysis
264351
# ============================================================================
@@ -290,9 +377,7 @@ def compare_ab_results(
290377
print("Performance Summary")
291378
print("-" * 70)
292379

293-
summary = _calculate_performance_summary(
294-
result_a, result_b, common_x_vals, common_backends
295-
)
380+
summary = _calculate_performance_summary(all_comparisons, common_backends)
296381

297382
for backend in common_backends:
298383
print(f"\n{backend}:")
@@ -320,8 +405,10 @@ def compare_ab_results(
320405

321406
x_val_name = REGISTERED_X_VALS.get(result_a.op_name, "x_val")
322407

323-
# Show all metrics for detailed comparison
408+
# Show only metrics that have comparable data
324409
for metric in result_a.metrics:
410+
if metric not in all_comparisons or len(all_comparisons[metric]) == 0:
411+
continue # Skip metrics with no comparable data
325412
print(f"\nMetric: {metric}")
326413
print("Backend".ljust(15), end="")
327414
print(x_val_name.ljust(20), end="")
@@ -330,49 +417,36 @@ def compare_ab_results(
330417
print("Difference".ljust(12))
331418
print("-" * 71)
332419

333-
for backend in common_backends:
334-
first_row = True
335-
for x_val in common_x_vals:
336-
if (
337-
backend not in result_dict_a[x_val]
338-
or backend not in result_dict_b[x_val]
339-
):
340-
continue
420+
# Use pre-computed comparisons
421+
comparisons = all_comparisons.get(metric, [])
341422

342-
metrics_a = result_dict_a[x_val][backend]
343-
metrics_b = result_dict_b[x_val][backend]
344-
345-
val_a = getattr(metrics_a, metric, None)
346-
val_b = getattr(metrics_b, metric, None)
347-
348-
if val_a is not None and val_b is not None:
349-
# Handle different data types
350-
if hasattr(val_a, "p50"):
351-
val_a_num = val_a.p50
352-
val_b_num = val_b.p50
353-
else:
354-
val_a_num = val_a
355-
val_b_num = val_b
423+
# Group by backend for display
424+
backend_comparisons = {}
425+
for comp in comparisons:
426+
if comp.backend not in backend_comparisons:
427+
backend_comparisons[comp.backend] = []
428+
backend_comparisons[comp.backend].append(comp)
356429

357-
if val_a_num != 0:
358-
diff_pct = ((val_b_num - val_a_num) / val_a_num) * 100
359-
else:
360-
diff_pct = 0
361-
362-
# Format values
363-
if isinstance(val_a_num, float):
364-
val_a_str = f"{val_a_num:.3f}"
365-
val_b_str = f"{val_b_num:.3f}"
366-
else:
367-
val_a_str = str(val_a_num)
368-
val_b_str = str(val_b_num)
369-
370-
# Print row
371-
backend_name = backend if first_row else ""
372-
print(
373-
f"{backend_name:<15}{str(x_val):<20}{val_a_str:<12}{val_b_str:<12}{diff_pct:+5.1f}%"
374-
)
375-
first_row = False
430+
for backend in common_backends:
431+
if backend not in backend_comparisons:
432+
continue
433+
434+
first_row = True
435+
for comp in backend_comparisons[backend]:
436+
# Format values
437+
if isinstance(comp.val_a, float):
438+
val_a_str = f"{comp.val_a:.3f}"
439+
val_b_str = f"{comp.val_b:.3f}"
440+
else:
441+
val_a_str = str(comp.val_a)
442+
val_b_str = str(comp.val_b)
443+
444+
# Print row
445+
backend_name = backend if first_row else ""
446+
print(
447+
f"{backend_name:<15}{str(comp.x_val):<20}{val_a_str:<12}{val_b_str:<12}{comp.improvement_pct:+5.1f}%"
448+
)
449+
first_row = False
376450

377451
if not first_row: # Only print separator if we printed data
378452
print()

0 commit comments

Comments
 (0)