Skip to content

Commit 5dcd266

Browse files
committed
Enhance A/B testing error handling
Add a few more try/except blocks
1 parent ff9c86e commit 5dcd266

File tree

1 file changed

+43
-15
lines changed

1 file changed

+43
-15
lines changed

tritonbench/utils/ab_test.py

Lines changed: 43 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -100,9 +100,11 @@ def update_args_with_global(base_args: argparse.Namespace, global_args: List[str
100100
if value is not None and key not in ['side_a', 'side_b']:
101101
setattr(updated_args, key, value)
102102

103-
except SystemExit:
103+
except SystemExit as e:
104104
# If parsing fails, keep original args
105-
pass
105+
print(f"WARNING: Failed to parse global arguments {global_args}, using original args: {e}")
106+
except Exception as e:
107+
print(f"WARNING: Unexpected error parsing global arguments {global_args}: {e}")
106108

107109
return updated_args
108110

@@ -245,13 +247,16 @@ def compare_ab_results(result_a: BenchmarkOperatorResult, result_b: BenchmarkOpe
245247
print("=" * 70)
246248

247249
print("Configuration Differences:")
248-
differences = _analyze_config_differences(config_a_args, config_b_args)
249-
250-
if differences:
251-
for param, (val_a, val_b) in differences.items():
252-
print(f" {param:<15}: {val_a:<15}{val_b}")
253-
else:
254-
print(" No configuration differences detected")
250+
try:
251+
differences = _analyze_config_differences(config_a_args, config_b_args)
252+
253+
if differences:
254+
for param, (val_a, val_b) in differences.items():
255+
print(f" {param:<15}: {val_a:<15}{val_b}")
256+
else:
257+
print(" No configuration differences detected")
258+
except Exception as e:
259+
print(f" ERROR: Failed to analyze configuration differences: {e}")
255260

256261
print(f"\nTest Scope: {len(common_x_vals)} input shapes, {len(common_backends)} backends")
257262
print(f"Metrics: {', '.join(result_a.metrics)}")
@@ -347,8 +352,17 @@ def run_ab_test(base_args: argparse.Namespace, base_extra_args: List[str], _run_
347352
"""Run A/B test with two configurations and return both results."""
348353

349354
# Parse A and B configurations
350-
config_a_args = parse_ab_config(base_args.side_a)
351-
config_b_args = parse_ab_config(base_args.side_b)
355+
try:
356+
config_a_args = parse_ab_config(base_args.side_a)
357+
except ValueError as e:
358+
print(f"ERROR: Failed to parse Side A configuration: {e}")
359+
raise
360+
361+
try:
362+
config_b_args = parse_ab_config(base_args.side_b)
363+
except ValueError as e:
364+
print(f"ERROR: Failed to parse Side B configuration: {e}")
365+
raise
352366

353367
print(f"[A/B Test] Configuration A: {' '.join(config_a_args)}")
354368
print(f"[A/B Test] Configuration B: {' '.join(config_b_args)}")
@@ -376,21 +390,35 @@ def run_ab_test(base_args: argparse.Namespace, base_extra_args: List[str], _run_
376390
extra_args_b = base_extra_args + op_b_args
377391

378392
print("=" * 60)
379-
print(f"Running Configuration A: {' '.join(config_a_args)}")
393+
print(f"Running Side A: {' '.join(config_a_args)}")
380394
if global_a_args:
381395
print(f" Global args: {' '.join(global_a_args)}")
382396
if op_a_args:
383397
print(f" Operator args: {' '.join(op_a_args)}")
384398
print("=" * 60)
385-
result_a = _run_func(args_a, extra_args_a)
399+
400+
try:
401+
result_a = _run_func(args_a, extra_args_a)
402+
if not result_a:
403+
raise RuntimeError("Side A returned empty result")
404+
except Exception as e:
405+
print(f"ERROR: Side A failed to run: {e}")
406+
raise RuntimeError(f"A/B test failed - Side A error: {e}")
386407

387408
print("\n" + "=" * 60)
388-
print(f"Running Configuration B: {' '.join(config_b_args)}")
409+
print(f"Running Side B: {' '.join(config_b_args)}")
389410
if global_b_args:
390411
print(f" Global args: {' '.join(global_b_args)}")
391412
if op_b_args:
392413
print(f" Operator args: {' '.join(op_b_args)}")
393414
print("=" * 60)
394-
result_b = _run_func(args_b, extra_args_b)
415+
416+
try:
417+
result_b = _run_func(args_b, extra_args_b)
418+
if not result_b:
419+
raise RuntimeError("Side B returned empty result")
420+
except Exception as e:
421+
print(f"ERROR: Side B failed to run: {e}")
422+
raise RuntimeError(f"A/B test failed - Side B error: {e}")
395423

396424
return result_a, result_b

0 commit comments

Comments
 (0)