
Commit a4c99bd

fix lint
1 parent 35cf1ba commit a4c99bd


tritonbench/utils/ab_test.py

Lines changed: 78 additions & 152 deletions
@@ -2,103 +2,13 @@
 
 import argparse
 import shlex
-from typing import Dict, List, Tuple, Optional, Any, Union
-from dataclasses import dataclass
+from typing import Dict, List, Tuple
 
 from .parser import get_parser
 
 from .triton_op import BenchmarkOperatorResult, REGISTERED_X_VALS
 
 
-@dataclass
-class MetricComparison:
-    """Data class to store metric comparison results."""
-    val_a: float
-    val_b: float
-    improvement_pct: float
-    x_val: Any
-    backend: str
-    metric: str
-
-
-def _extract_metric_value(metric_obj: Any) -> Optional[float]:
-    """Extract and normalize metric value from metric objects.
-
-    Only handles:
-    - Objects with p50 attribute (percentile-based metrics)
-    - Direct numeric values (int/float)
-
-    Skips all other types (tuples, complex objects, etc.)
-    """
-    if metric_obj is None:
-        return None
-
-    # Handle objects with p50 attribute (percentile-based metrics)
-    if hasattr(metric_obj, "p50"):
-        return float(metric_obj.p50)
-
-    # Handle direct numeric values only
-    if isinstance(metric_obj, (int, float)):
-        return float(metric_obj)
-
-    # Skip all other types (tuples, complex objects, etc.)
-    return None
-
-
-def _calculate_improvement(val_a: float, val_b: float) -> float:
-    """Calculate percentage improvement from val_a to val_b."""
-    if val_a == 0:
-        return 0.0
-    return ((val_b - val_a) / val_a) * 100
-
-
-def _get_comparable_data_points(
-    result_a: BenchmarkOperatorResult,
-    result_b: BenchmarkOperatorResult,
-    common_x_vals: List,
-    common_backends: List[str],
-    metric: str,
-) -> List[MetricComparison]:
-    """Get all comparable data points for a specific metric."""
-    # Create result dictionaries for easier lookup
-    result_dict_a = {x_val: metrics_dict for x_val, metrics_dict in result_a.result}
-    result_dict_b = {x_val: metrics_dict for x_val, metrics_dict in result_b.result}
-
-    comparisons = []
-
-    for backend in common_backends:
-        for x_val in common_x_vals:
-            if backend in result_dict_a[x_val] and backend in result_dict_b[x_val]:
-                metrics_a = result_dict_a[x_val][backend]
-                metrics_b = result_dict_b[x_val][backend]
-
-                # Try to get the metric from direct attribute first
-                raw_val_a = getattr(metrics_a, metric, None)
-                raw_val_b = getattr(metrics_b, metric, None)
-
-                # If not found, check in extra_metrics
-                if raw_val_a is None and hasattr(metrics_a, 'extra_metrics') and metrics_a.extra_metrics:
-                    raw_val_a = metrics_a.extra_metrics.get(metric, None)
-                if raw_val_b is None and hasattr(metrics_b, 'extra_metrics') and metrics_b.extra_metrics:
-                    raw_val_b = metrics_b.extra_metrics.get(metric, None)
-
-                val_a = _extract_metric_value(raw_val_a)
-                val_b = _extract_metric_value(raw_val_b)
-
-                if val_a is not None and val_b is not None:
-                    improvement_pct = _calculate_improvement(val_a, val_b)
-                    comparisons.append(MetricComparison(
-                        val_a=val_a,
-                        val_b=val_b,
-                        improvement_pct=improvement_pct,
-                        x_val=x_val,
-                        backend=backend,
-                        metric=metric
-                    ))
-
-    return comparisons
-
-
 def parse_ab_config(config_str: str) -> List[str]:
     """Parse A/B configuration string into argument list."""
     if not config_str:
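For reference, the percentage formula from the removed `_calculate_improvement` helper survives inline in the rewritten sections below. A minimal sketch of its sign convention, with made-up latency numbers:

    # Made-up values for illustration: config A at 2.0 ms, config B at 1.5 ms.
    val_a, val_b = 2.0, 1.5
    improvement_pct = ((val_b - val_a) / val_a) * 100
    print(improvement_pct)  # -25.0: for a latency-style metric, negative means B is faster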
@@ -253,37 +163,48 @@ def parse_config_to_dict(args):
     return differences
 
 
-def _get_all_comparable_data_points(
+def _calculate_performance_summary(
     result_a: BenchmarkOperatorResult,
     result_b: BenchmarkOperatorResult,
     common_x_vals: List,
     common_backends: List[str],
-) -> Dict[str, List[MetricComparison]]:
-    """Get all comparable data points for all metrics at once."""
-    all_comparisons = {}
-
-    for metric in result_a.metrics:
-        all_comparisons[metric] = _get_comparable_data_points(
-            result_a, result_b, common_x_vals, common_backends, metric
-        )
-
-    return all_comparisons
-
-
-def _calculate_performance_summary(
-    all_comparisons: Dict[str, List[MetricComparison]],
-    common_backends: List[str],
 ) -> Dict[str, Dict[str, float]]:
-    """Calculate performance summary statistics from pre-computed comparisons."""
+    """Calculate performance summary statistics."""
     summary = {}
 
+    # Create result dictionaries for easier lookup
+    result_dict_a = {x_val: metrics_dict for x_val, metrics_dict in result_a.result}
+    result_dict_b = {x_val: metrics_dict for x_val, metrics_dict in result_b.result}
+
     for backend in common_backends:
         backend_summary = {}
 
-        for metric, comparisons in all_comparisons.items():
-            # Filter for current backend
-            backend_comparisons = [c for c in comparisons if c.backend == backend]
-            improvements = [c.improvement_pct for c in backend_comparisons]
+        for metric in result_a.metrics:
+            improvements = []
+
+            for x_val in common_x_vals:
+                if backend in result_dict_a[x_val] and backend in result_dict_b[x_val]:
+                    metrics_a = result_dict_a[x_val][backend]
+                    metrics_b = result_dict_b[x_val][backend]
+
+                    val_a = getattr(metrics_a, metric, None)
+                    val_b = getattr(metrics_b, metric, None)
+
+                    if val_a is not None and val_b is not None:
+                        # Handle different metric types
+                        if hasattr(val_a, "p50"):
+                            val_a_num = val_a.p50
+                        else:
+                            val_a_num = val_a
+
+                        if hasattr(val_b, "p50"):
+                            val_b_num = val_b.p50
+                        else:
+                            val_b_num = val_b
+
+                        if val_a_num != 0:
+                            improvement = ((val_b_num - val_a_num) / val_a_num) * 100
+                            improvements.append(improvement)
 
             if improvements:
                 backend_summary[metric] = {
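A minimal sketch of the duck typing the inlined loop above relies on: percentile-style metric objects expose a `p50` attribute, while plain numbers pass through unchanged. `FakeLatency` is a hypothetical stand-in for illustration, not the real tritonbench metric type:

    from dataclasses import dataclass

    @dataclass
    class FakeLatency:  # hypothetical stand-in for a percentile-based metric object
        p50: float

    for raw in (FakeLatency(p50=1.23), 4.56):
        # Same pattern as the diff: prefer .p50 when present, else use the number as-is.
        num = raw.p50 if hasattr(raw, "p50") else raw
        print(num)  # 1.23, then 4.56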
@@ -338,14 +259,6 @@ def compare_ab_results(
         print("ERROR: No common backends found between configurations")
         return
 
-    # ============================================================================
-    # PRE-COMPUTE: Get all comparable data points once
-    # ============================================================================
-    all_comparisons = _get_all_comparable_data_points(
-        result_a, result_b, common_x_vals, common_backends
-    )
-
-
     # ============================================================================
     # SECTION 1: Configuration Analysis
     # ============================================================================
@@ -377,7 +290,9 @@ def compare_ab_results(
     print("Performance Summary")
     print("-" * 70)
 
-    summary = _calculate_performance_summary(all_comparisons, common_backends)
+    summary = _calculate_performance_summary(
+        result_a, result_b, common_x_vals, common_backends
+    )
 
     for backend in common_backends:
         print(f"\n{backend}:")
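The diff cuts off before the body of `backend_summary[metric] = {`, so the inner keys are not visible here. Assuming the return shape `summary[backend][metric] -> dict of aggregate stats`, the call site above consumes it roughly like this (the `avg_improvement_pct` key is a hypothetical placeholder, not confirmed by this diff):

    summary = {"triton": {"latency": {"avg_improvement_pct": -25.0}}}
    for backend, metrics in summary.items():
        print(f"\n{backend}:")
        for metric, stats in metrics.items():
            print(f"  {metric}: {stats}")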
@@ -405,10 +320,8 @@ def compare_ab_results(
 
     x_val_name = REGISTERED_X_VALS.get(result_a.op_name, "x_val")
 
-    # Show only metrics that have comparable data
+    # Show all metrics for detailed comparison
     for metric in result_a.metrics:
-        if metric not in all_comparisons or len(all_comparisons[metric]) == 0:
-            continue  # Skip metrics with no comparable data
         print(f"\nMetric: {metric}")
         print("Backend".ljust(15), end="")
         print(x_val_name.ljust(20), end="")
@@ -417,36 +330,49 @@ def compare_ab_results(
         print("Difference".ljust(12))
         print("-" * 71)
 
-        # Use pre-computed comparisons
-        comparisons = all_comparisons.get(metric, [])
-
-        # Group by backend for display
-        backend_comparisons = {}
-        for comp in comparisons:
-            if comp.backend not in backend_comparisons:
-                backend_comparisons[comp.backend] = []
-            backend_comparisons[comp.backend].append(comp)
-
         for backend in common_backends:
-            if backend not in backend_comparisons:
-                continue
-
             first_row = True
-            for comp in backend_comparisons[backend]:
-                # Format values
-                if isinstance(comp.val_a, float):
-                    val_a_str = f"{comp.val_a:.3f}"
-                    val_b_str = f"{comp.val_b:.3f}"
-                else:
-                    val_a_str = str(comp.val_a)
-                    val_b_str = str(comp.val_b)
-
-                # Print row
-                backend_name = backend if first_row else ""
-                print(
-                    f"{backend_name:<15}{str(comp.x_val):<20}{val_a_str:<12}{val_b_str:<12}{comp.improvement_pct:+5.1f}%"
-                )
-                first_row = False
+            for x_val in common_x_vals:
+                if (
+                    backend not in result_dict_a[x_val]
+                    or backend not in result_dict_b[x_val]
+                ):
+                    continue
+
+                metrics_a = result_dict_a[x_val][backend]
+                metrics_b = result_dict_b[x_val][backend]
+
+                val_a = getattr(metrics_a, metric, None)
+                val_b = getattr(metrics_b, metric, None)
+
+                if val_a is not None and val_b is not None:
+                    # Handle different data types
+                    if hasattr(val_a, "p50"):
+                        val_a_num = val_a.p50
+                        val_b_num = val_b.p50
+                    else:
+                        val_a_num = val_a
+                        val_b_num = val_b
+
+                    if val_a_num != 0:
+                        diff_pct = ((val_b_num - val_a_num) / val_a_num) * 100
+                    else:
+                        diff_pct = 0
+
+                    # Format values
+                    if isinstance(val_a_num, float):
+                        val_a_str = f"{val_a_num:.3f}"
+                        val_b_str = f"{val_b_num:.3f}"
+                    else:
+                        val_a_str = str(val_a_num)
+                        val_b_str = str(val_b_num)
+
+                    # Print row
+                    backend_name = backend if first_row else ""
+                    print(
+                        f"{backend_name:<15}{str(x_val):<20}{val_a_str:<12}{val_b_str:<12}{diff_pct:+5.1f}%"
+                    )
+                    first_row = False
 
         if not first_row:  # Only print separator if we printed data
             print()
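The fixed-width f-string above is what produces the aligned table rows; a quick sketch with made-up values shows the column layout (15/20/12/12 characters plus a signed percentage):

    backend_name, x_val = "triton", "(1024, 1024)"
    val_a_str, val_b_str, diff_pct = "1.500", "1.350", -10.0
    print(f"{backend_name:<15}{str(x_val):<20}{val_a_str:<12}{val_b_str:<12}{diff_pct:+5.1f}%")
    # triton         (1024, 1024)        1.500       1.350       -10.0%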
