Skip to content

Commit 667ffd4

Browse files
committed
A/B testing results collection and comparison output
- Create dedicated ab_test.py module for A/B testing functions - Add detailed results comparison with performance calculations - Support tabular output across multiple backends and input shapes - Clean up run.py by moving A/B logic to separate module
1 parent c7fb992 commit 667ffd4

File tree

2 files changed

+287
-130
lines changed

2 files changed

+287
-130
lines changed

run.py

Lines changed: 9 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from tritonbench.utils.triton_op import BenchmarkOperatorResult
2525
from tritonbench.utils.tritonparse_utils import tritonparse_init, tritonparse_parse
26+
from tritonbench.utils.ab_test import run_ab_test, compare_ab_results
2627

2728
try:
2829
if is_fbcode():
@@ -33,131 +34,6 @@
3334
usage_report_logger = lambda *args, **kwargs: None
3435

3536

36-
def parse_ab_config(config_str: str) -> List[str]:
37-
"""Parse A/B configuration string into list of arguments."""
38-
if not config_str:
39-
return []
40-
41-
# Use shlex to properly handle quoted arguments
42-
try:
43-
return shlex.split(config_str)
44-
except ValueError as e:
45-
raise ValueError(f"Invalid configuration string: {config_str}. Error: {e}")
46-
47-
48-
def separate_global_and_op_args(config_args: List[str]) -> Tuple[List[str], List[str]]:
49-
"""Separate global tritonbench args from operator-specific args."""
50-
if not config_args:
51-
return [], []
52-
53-
# Create a temporary parser with only global arguments to identify which args are global
54-
temp_parser = get_parser()
55-
56-
# Parse the config args to separate global from operator-specific
57-
try:
58-
# Use parse_known_args to get global args and remaining (operator) args
59-
global_args, op_args = temp_parser.parse_known_args(config_args)
60-
61-
# Simple approach: just return the input config_args that were recognized as global
62-
# and the remaining op_args
63-
global_arg_list = []
64-
i = 0
65-
while i < len(config_args):
66-
arg = config_args[i]
67-
if arg.startswith('--'):
68-
# Check if this arg was consumed by the global parser
69-
arg_name = arg[2:].replace('-', '_')
70-
if hasattr(global_args, arg_name) and arg_name not in ['side_a', 'side_b']:
71-
global_arg_list.append(arg)
72-
# If it's not a flag and has a value, include the value too
73-
if i + 1 < len(config_args) and not config_args[i + 1].startswith('-'):
74-
i += 1
75-
global_arg_list.append(config_args[i])
76-
i += 1
77-
78-
return global_arg_list, op_args
79-
80-
except SystemExit:
81-
# If parsing fails, treat all args as operator-specific
82-
return [], config_args
83-
84-
85-
def update_args_with_global(base_args: argparse.Namespace, global_args: List[str]) -> argparse.Namespace:
86-
"""Update base args with global arguments from A/B config."""
87-
if not global_args:
88-
return argparse.Namespace(**vars(base_args))
89-
90-
# Create a copy of base args
91-
updated_args = argparse.Namespace(**vars(base_args))
92-
93-
# Parse global args and update the namespace
94-
temp_parser = get_parser()
95-
try:
96-
parsed_globals, _ = temp_parser.parse_known_args(global_args)
97-
98-
# Update the namespace with new global values
99-
for key, value in vars(parsed_globals).items():
100-
if value is not None and key not in ['side_a', 'side_b']:
101-
setattr(updated_args, key, value)
102-
103-
except SystemExit:
104-
# If parsing fails, keep original args
105-
pass
106-
107-
return updated_args
108-
109-
110-
def run_ab_test(base_args: argparse.Namespace, base_extra_args: List[str]) -> Tuple[BenchmarkOperatorResult, BenchmarkOperatorResult]:
111-
"""Run A/B test with two configurations and return both results."""
112-
113-
# Parse A and B configurations
114-
config_a_args = parse_ab_config(base_args.side_a)
115-
config_b_args = parse_ab_config(base_args.side_b)
116-
117-
print(f"[A/B Test] Configuration A: {' '.join(config_a_args)}")
118-
print(f"[A/B Test] Configuration B: {' '.join(config_b_args)}")
119-
120-
# Separate global and operator-specific arguments
121-
global_a_args, op_a_args = separate_global_and_op_args(config_a_args)
122-
global_b_args, op_b_args = separate_global_and_op_args(config_b_args)
123-
124-
if global_a_args:
125-
print(f"[A/B Test] Global args A: {' '.join(global_a_args)}")
126-
if op_a_args:
127-
print(f"[A/B Test] Operator args A: {' '.join(op_a_args)}")
128-
if global_b_args:
129-
print(f"[A/B Test] Global args B: {' '.join(global_b_args)}")
130-
if op_b_args:
131-
print(f"[A/B Test] Operator args B: {' '.join(op_b_args)}")
132-
print()
133-
134-
# Update args with global parameters
135-
args_a = update_args_with_global(base_args, global_a_args)
136-
args_b = update_args_with_global(base_args, global_b_args)
137-
138-
# Combine extra_args with operator-specific args only
139-
extra_args_a = base_extra_args + op_a_args
140-
extra_args_b = base_extra_args + op_b_args
141-
142-
print("=" * 60)
143-
print(f"Running Configuration A: {' '.join(config_a_args)}")
144-
if global_a_args:
145-
print(f" Global args: {' '.join(global_a_args)}")
146-
if op_a_args:
147-
print(f" Operator args: {' '.join(op_a_args)}")
148-
print("=" * 60)
149-
result_a = _run(args_a, extra_args_a)
150-
151-
print("\n" + "=" * 60)
152-
print(f"Running Configuration B: {' '.join(config_b_args)}")
153-
if global_b_args:
154-
print(f" Global args: {' '.join(global_b_args)}")
155-
if op_b_args:
156-
print(f" Operator args: {' '.join(op_b_args)}")
157-
print("=" * 60)
158-
result_b = _run(args_b, extra_args_b)
159-
160-
return result_a, result_b
16137

16238

16339
def _run(args: argparse.Namespace, extra_args: List[str]) -> BenchmarkOperatorResult:
@@ -266,11 +142,14 @@ def run(args: List[str] = []):
266142

267143
with gpu_lockdown(args.gpu_lockdown):
268144
try:
269-
result_a, result_b = run_ab_test(args, extra_args)
270-
# TODO: Phase 3 - Implement A/B comparison output
271-
print("\n[A/B Test Results]")
272-
print("Configuration A result:", result_a.benchmark_name if result_a else "Failed")
273-
print("Configuration B result:", result_b.benchmark_name if result_b else "Failed")
145+
result_a, result_b = run_ab_test(args, extra_args, _run)
146+
147+
# Phase 3: Implement A/B comparison output
148+
from tritonbench.utils.ab_test import parse_ab_config
149+
config_a_args = parse_ab_config(args.side_a)
150+
config_b_args = parse_ab_config(args.side_b)
151+
compare_ab_results(result_a, result_b, config_a_args, config_b_args)
152+
274153
except Exception as e:
275154
print(f"A/B test failed: {e}")
276155
if not args.bypass_fail:

0 commit comments

Comments
 (0)