
Commit ff9c86e

Reformat A/B testing output
Restructure output into three sections:
- Configuration Analysis: shows parameter differences between A and B
- Performance Summary: displays average improvements per backend/metric
- Detailed Comparison: shows all metrics in tabular format per input size
1 parent 667ffd4 commit ff9c86e
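
As a rough illustration of the new Configuration Analysis section, the sketch below runs two hypothetical argument lists through the _analyze_config_differences helper added in this commit. The flag names and values are made up for the example, and the expected result is inferred from the parsing logic in the diff below.

# Illustrative sketch only (not part of the commit); the argument lists are hypothetical.
from tritonbench.utils.ab_test import _analyze_config_differences

config_a_args = ["--op", "gemm", "--precision", "fp16", "--num-inputs=8"]
config_b_args = ["--op", "gemm", "--precision", "bf16", "--num-inputs=8", "--cudagraph"]

# Only parameters that differ are reported; a flag present in just one config
# shows up as "default" on the other side, e.g.:
#   {"precision": ("fp16", "bf16"), "cudagraph": ("default", "True")}
diffs = _analyze_config_differences(config_a_args, config_b_args)
for param, (val_a, val_b) in diffs.items():
    print(f"{param}: A={val_a}  B={val_b}")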

File tree

1 file changed: +170 -52 lines changed


tritonbench/utils/ab_test.py

Lines changed: 170 additions & 52 deletions
@@ -107,26 +107,113 @@ def update_args_with_global(base_args: argparse.Namespace, global_args: List[str
     return updated_args
 
 
+def _analyze_config_differences(config_a_args: List[str], config_b_args: List[str]) -> Dict[str, Tuple[str, str]]:
+    """Analyze differences between two configurations."""
+    # Parse arguments into dictionaries
+    def parse_config_to_dict(args):
+        config_dict = {}
+        i = 0
+        while i < len(args):
+            if args[i].startswith('--'):
+                key = args[i][2:] # Remove --
+                if '=' in args[i]:
+                    # Format: --key=value
+                    key, value = args[i][2:].split('=', 1)
+                    config_dict[key] = value
+                    i += 1
+                elif i + 1 < len(args) and not args[i + 1].startswith('-'):
+                    # Format: --key value
+                    config_dict[key] = args[i + 1]
+                    i += 2
+                else:
+                    # Flag without value
+                    config_dict[key] = "True"
+                    i += 1
+            else:
+                i += 1
+        return config_dict
+
+    config_a = parse_config_to_dict(config_a_args)
+    config_b = parse_config_to_dict(config_b_args)
+
+    # Find differences
+    differences = {}
+    all_keys = set(config_a.keys()) | set(config_b.keys())
+
+    for key in all_keys:
+        val_a = config_a.get(key, "default")
+        val_b = config_b.get(key, "default")
+        if val_a != val_b:
+            differences[key] = (val_a, val_b)
+
+    return differences
+
+
+def _calculate_performance_summary(result_a: BenchmarkOperatorResult, result_b: BenchmarkOperatorResult,
+                                   common_x_vals: List, common_backends: List[str]) -> Dict[str, Dict[str, float]]:
+    """Calculate performance summary statistics."""
+    summary = {}
+
+    # Create result dictionaries for easier lookup
+    result_dict_a = {x_val: metrics_dict for x_val, metrics_dict in result_a.result}
+    result_dict_b = {x_val: metrics_dict for x_val, metrics_dict in result_b.result}
+
+    for backend in common_backends:
+        backend_summary = {}
+
+        for metric in result_a.metrics:
+            improvements = []
+
+            for x_val in common_x_vals:
+                if (backend in result_dict_a[x_val] and backend in result_dict_b[x_val]):
+                    metrics_a = result_dict_a[x_val][backend]
+                    metrics_b = result_dict_b[x_val][backend]
+
+                    val_a = getattr(metrics_a, metric, None)
+                    val_b = getattr(metrics_b, metric, None)
+
+                    if val_a is not None and val_b is not None:
+                        # Handle different metric types
+                        if hasattr(val_a, 'p50'):
+                            val_a_num = val_a.p50
+                        else:
+                            val_a_num = val_a
+
+                        if hasattr(val_b, 'p50'):
+                            val_b_num = val_b.p50
+                        else:
+                            val_b_num = val_b
+
+                        if val_a_num != 0:
+                            improvement = ((val_b_num - val_a_num) / val_a_num) * 100
+                            improvements.append(improvement)
+
+            if improvements:
+                backend_summary[metric] = {
+                    'avg_improvement': sum(improvements) / len(improvements),
+                    'min_improvement': min(improvements),
+                    'max_improvement': max(improvements),
+                    'count': len(improvements)
+                }
+
+        summary[backend] = backend_summary
+
+    return summary
+
+
 def compare_ab_results(result_a: BenchmarkOperatorResult, result_b: BenchmarkOperatorResult,
                        config_a_args: List[str], config_b_args: List[str]):
-    """Compare A/B test results and display formatted comparison."""
+    """Compare A/B test results"""
     if not result_a or not result_b:
         print("\n[A/B Comparison] ERROR: One or both results are invalid")
         return
 
-    print("\n" + "=" * 80)
-    print(f"[A/B Test Results Comparison] - {result_a.op_name}")
-    print("=" * 80)
-    print(f"Configuration A: {' '.join(config_a_args)}")
-    print(f"Configuration B: {' '.join(config_b_args)}")
-    print()
-
     # Check if both results have data
     if not result_a.result or not result_b.result:
         print("ERROR: No benchmark data available for comparison")
         return
 
-    # Get all x_vals (input shapes) that are common to both results
+    # Get common data for analysis
     x_vals_a = {x_val for x_val, _ in result_a.result}
     x_vals_b = {x_val for x_val, _ in result_b.result}
     common_x_vals = sorted(x_vals_a.intersection(x_vals_b))
@@ -135,11 +222,10 @@ def compare_ab_results(result_a: BenchmarkOperatorResult, result_b: BenchmarkOpe
         print("ERROR: No common input shapes found between configurations")
         return
 
-    # Create result dictionaries for easier lookup
+    # Get common backends
     result_dict_a = {x_val: metrics_dict for x_val, metrics_dict in result_a.result}
     result_dict_b = {x_val: metrics_dict for x_val, metrics_dict in result_b.result}
 
-    # Get all backends that are common to both results
    all_backends_a = set()
    all_backends_b = set()
    for x_val in common_x_vals:
@@ -151,58 +237,89 @@ def compare_ab_results(result_a: BenchmarkOperatorResult, result_b: BenchmarkOpe
         print("ERROR: No common backends found between configurations")
         return
 
-    print(f"Comparing {len(common_x_vals)} input shapes across {len(common_backends)} backends")
+    # ============================================================================
+    # SECTION 1: Configuration Analysis
+    # ============================================================================
+    print("\n" + "=" * 70)
+    print(f"A/B Test Results: {result_a.op_name}")
+    print("=" * 70)
+
+    print("Configuration Differences:")
+    differences = _analyze_config_differences(config_a_args, config_b_args)
+
+    if differences:
+        for param, (val_a, val_b) in differences.items():
+            print(f" {param:<15}: {val_a:<15}{val_b}")
+    else:
+        print(" No configuration differences detected")
+
+    print(f"\nTest Scope: {len(common_x_vals)} input shapes, {len(common_backends)} backends")
     print(f"Metrics: {', '.join(result_a.metrics)}")
-    print()
 
-    # Create comparison table
-    x_val_name = REGISTERED_X_VALS.get(result_a.op_name, "x_val")
+    # ============================================================================
+    # SECTION 2: Performance Summary
+    # ============================================================================
+    print("\n" + "-" * 70)
+    print("Performance Summary")
+    print("-" * 70)
+
+    summary = _calculate_performance_summary(result_a, result_b, common_x_vals, common_backends)
 
     for backend in common_backends:
-        print(f"Backend: {backend}")
-        print("-" * 60)
+        print(f"\n{backend}:")
+        backend_data = summary.get(backend, {})
 
-        # Create table headers
-        headers = [x_val_name]
-        for metric in result_a.metrics:
-            headers.extend([f"{metric}_A", f"{metric}_B", f"{metric}_diff%"])
+        if not backend_data:
+            print(" No comparable data")
+            continue
 
-        # Print headers
-        print("{:<15} ".format(headers[0]), end="")
-        for i in range(1, len(headers)):
-            print("{:<12} ".format(headers[i]), end="")
-        print()
-        print("-" * (15 + 12 * (len(headers) - 1)))
+        for metric, stats in backend_data.items():
+            avg_improvement = stats['avg_improvement']
+            min_improvement = stats['min_improvement']
+            max_improvement = stats['max_improvement']
+
+            print(f" {metric:<12}: {avg_improvement:+5.1f}% avg [{min_improvement:+.1f}% to {max_improvement:+.1f}%]")
+
+    # ============================================================================
+    # SECTION 3: Detailed Comparison (Compact)
+    # ============================================================================
+    print("\n" + "-" * 70)
+    print("Detailed Comparison")
+    print("-" * 70)
+
+    x_val_name = REGISTERED_X_VALS.get(result_a.op_name, "x_val")
+
+    # Show all metrics for detailed comparison
+    for metric in result_a.metrics:
+        print(f"\nMetric: {metric}")
+        print("Backend".ljust(15), end="")
+        print(x_val_name.ljust(20), end="")
+        print("Config A".ljust(12), end="")
+        print("Config B".ljust(12), end="")
+        print("Difference".ljust(12))
+        print("-" * 71)
 
-        # Print data rows
-        for x_val in common_x_vals:
-            if backend not in result_dict_a[x_val] or backend not in result_dict_b[x_val]:
-                continue
+        for backend in common_backends:
+            first_row = True
+            for x_val in common_x_vals:
+                if backend not in result_dict_a[x_val] or backend not in result_dict_b[x_val]:
+                    continue
+
+                metrics_a = result_dict_a[x_val][backend]
+                metrics_b = result_dict_b[x_val][backend]
 
-            metrics_a = result_dict_a[x_val][backend]
-            metrics_b = result_dict_b[x_val][backend]
-
-            # Print x_val
-            print("{:<15} ".format(str(x_val)), end="")
-
-            # Print metrics comparisons
-            for metric in result_a.metrics:
                 val_a = getattr(metrics_a, metric, None)
                 val_b = getattr(metrics_b, metric, None)
 
                 if val_a is not None and val_b is not None:
-                    # Handle latency objects
+                    # Handle different data types
                     if hasattr(val_a, 'p50'):
                         val_a_num = val_a.p50
-                    else:
-                        val_a_num = val_a
-
-                    if hasattr(val_b, 'p50'):
                         val_b_num = val_b.p50
                     else:
+                        val_a_num = val_a
                         val_b_num = val_b
 
-                    # Calculate percentage difference
                     if val_a_num != 0:
                         diff_pct = ((val_b_num - val_a_num) / val_a_num) * 100
                     else:
@@ -216,13 +333,14 @@ def compare_ab_results(result_a: BenchmarkOperatorResult, result_b: BenchmarkOpe
                         val_a_str = str(val_a_num)
                         val_b_str = str(val_b_num)
 
-                    print("{:<12} {:<12} {:<12} ".format(
-                        val_a_str, val_b_str, f"{diff_pct:+.1f}%"
-                    ), end="")
-                else:
-                    print("{:<12} {:<12} {:<12} ".format("N/A", "N/A", "N/A"), end="")
-            print()
-        print()
+                    # Print row
+                    backend_name = backend if first_row else ""
+                    print(f"{backend_name:<15}{str(x_val):<20}{val_a_str:<12}{val_b_str:<12}{diff_pct:+5.1f}%")
+                    first_row = False
+
+            if not first_row: # Only print separator if we printed data
+                print()
+
 
 
 def run_ab_test(base_args: argparse.Namespace, base_extra_args: List[str], _run_func) -> Tuple[BenchmarkOperatorResult, BenchmarkOperatorResult]:
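
A note on reading the numbers in the Performance Summary and the Detailed Comparison: both use the same signed percentage, ((B - A) / A) * 100, so the sign is always relative to configuration A. A minimal sketch with made-up latency values:

# Sketch of the sign convention only; the values are hypothetical.
val_a_p50 = 2.0  # p50 latency under config A (ms)
val_b_p50 = 1.7  # p50 latency under config B (ms)

improvement = ((val_b_p50 - val_a_p50) / val_a_p50) * 100
print(f"{improvement:+.1f}%")  # prints -15.0%

Whether a negative value is an improvement depends on the metric: for latency it means B is faster, while for throughput-style metrics it means B is lower; the code reports the raw signed percentage either way.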
