Skip to content

Commit c7fb992

Browse files
committed
Implement A/B testing core functionality
Core A/B Testing Features: - Add config parsing with proper argument handling - Implement run_ab_test function for dual execution - Support global and operator-specific parameters - Basic A/B test results display Example usage: python run.py --op flex_attention --side-a='--precision fp16 --max-autotune' --side-b='--precision bf16 --dynamic'
1 parent ba3fd8f commit c7fb992

File tree

2 files changed

+164
-13
lines changed

2 files changed

+164
-13
lines changed

run.py

Lines changed: 164 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
import argparse
99
import os
10+
import shlex
1011
import sys
11-
from typing import List
12+
from typing import List, Tuple
1213

1314
from tritonbench.operator_loader import get_op_loader_bench_cls_by_name, is_loader_op
1415

@@ -32,6 +33,133 @@
3233
usage_report_logger = lambda *args, **kwargs: None
3334

3435

36+
def parse_ab_config(config_str: str) -> List[str]:
37+
"""Parse A/B configuration string into list of arguments."""
38+
if not config_str:
39+
return []
40+
41+
# Use shlex to properly handle quoted arguments
42+
try:
43+
return shlex.split(config_str)
44+
except ValueError as e:
45+
raise ValueError(f"Invalid configuration string: {config_str}. Error: {e}")
46+
47+
48+
def separate_global_and_op_args(config_args: List[str]) -> Tuple[List[str], List[str]]:
49+
"""Separate global tritonbench args from operator-specific args."""
50+
if not config_args:
51+
return [], []
52+
53+
# Create a temporary parser with only global arguments to identify which args are global
54+
temp_parser = get_parser()
55+
56+
# Parse the config args to separate global from operator-specific
57+
try:
58+
# Use parse_known_args to get global args and remaining (operator) args
59+
global_args, op_args = temp_parser.parse_known_args(config_args)
60+
61+
# Simple approach: just return the input config_args that were recognized as global
62+
# and the remaining op_args
63+
global_arg_list = []
64+
i = 0
65+
while i < len(config_args):
66+
arg = config_args[i]
67+
if arg.startswith('--'):
68+
# Check if this arg was consumed by the global parser
69+
arg_name = arg[2:].replace('-', '_')
70+
if hasattr(global_args, arg_name) and arg_name not in ['side_a', 'side_b']:
71+
global_arg_list.append(arg)
72+
# If it's not a flag and has a value, include the value too
73+
if i + 1 < len(config_args) and not config_args[i + 1].startswith('-'):
74+
i += 1
75+
global_arg_list.append(config_args[i])
76+
i += 1
77+
78+
return global_arg_list, op_args
79+
80+
except SystemExit:
81+
# If parsing fails, treat all args as operator-specific
82+
return [], config_args
83+
84+
85+
def update_args_with_global(base_args: argparse.Namespace, global_args: List[str]) -> argparse.Namespace:
86+
"""Update base args with global arguments from A/B config."""
87+
if not global_args:
88+
return argparse.Namespace(**vars(base_args))
89+
90+
# Create a copy of base args
91+
updated_args = argparse.Namespace(**vars(base_args))
92+
93+
# Parse global args and update the namespace
94+
temp_parser = get_parser()
95+
try:
96+
parsed_globals, _ = temp_parser.parse_known_args(global_args)
97+
98+
# Update the namespace with new global values
99+
for key, value in vars(parsed_globals).items():
100+
if value is not None and key not in ['side_a', 'side_b']:
101+
setattr(updated_args, key, value)
102+
103+
except SystemExit:
104+
# If parsing fails, keep original args
105+
pass
106+
107+
return updated_args
108+
109+
110+
def run_ab_test(base_args: argparse.Namespace, base_extra_args: List[str]) -> Tuple[BenchmarkOperatorResult, BenchmarkOperatorResult]:
111+
"""Run A/B test with two configurations and return both results."""
112+
113+
# Parse A and B configurations
114+
config_a_args = parse_ab_config(base_args.side_a)
115+
config_b_args = parse_ab_config(base_args.side_b)
116+
117+
print(f"[A/B Test] Configuration A: {' '.join(config_a_args)}")
118+
print(f"[A/B Test] Configuration B: {' '.join(config_b_args)}")
119+
120+
# Separate global and operator-specific arguments
121+
global_a_args, op_a_args = separate_global_and_op_args(config_a_args)
122+
global_b_args, op_b_args = separate_global_and_op_args(config_b_args)
123+
124+
if global_a_args:
125+
print(f"[A/B Test] Global args A: {' '.join(global_a_args)}")
126+
if op_a_args:
127+
print(f"[A/B Test] Operator args A: {' '.join(op_a_args)}")
128+
if global_b_args:
129+
print(f"[A/B Test] Global args B: {' '.join(global_b_args)}")
130+
if op_b_args:
131+
print(f"[A/B Test] Operator args B: {' '.join(op_b_args)}")
132+
print()
133+
134+
# Update args with global parameters
135+
args_a = update_args_with_global(base_args, global_a_args)
136+
args_b = update_args_with_global(base_args, global_b_args)
137+
138+
# Combine extra_args with operator-specific args only
139+
extra_args_a = base_extra_args + op_a_args
140+
extra_args_b = base_extra_args + op_b_args
141+
142+
print("=" * 60)
143+
print(f"Running Configuration A: {' '.join(config_a_args)}")
144+
if global_a_args:
145+
print(f" Global args: {' '.join(global_a_args)}")
146+
if op_a_args:
147+
print(f" Operator args: {' '.join(op_a_args)}")
148+
print("=" * 60)
149+
result_a = _run(args_a, extra_args_a)
150+
151+
print("\n" + "=" * 60)
152+
print(f"Running Configuration B: {' '.join(config_b_args)}")
153+
if global_b_args:
154+
print(f" Global args: {' '.join(global_b_args)}")
155+
if op_b_args:
156+
print(f" Operator args: {' '.join(op_b_args)}")
157+
print("=" * 60)
158+
result_b = _run(args_b, extra_args_b)
159+
160+
return result_a, result_b
161+
162+
35163
def _run(args: argparse.Namespace, extra_args: List[str]) -> BenchmarkOperatorResult:
36164
if is_loader_op(args.op):
37165
Opbench = get_op_loader_bench_cls_by_name(args.op)
@@ -125,17 +253,42 @@ def run(args: List[str] = []):
125253
)
126254
return
127255

128-
# Force isolation in subprocess if testing more than one op.
129-
if len(ops) >= 2:
130-
args.isolate = True
256+
# Check if A/B testing mode is enabled
257+
if args.side_a is not None and args.side_b is not None:
258+
# A/B testing mode - only support single operator
259+
assert len(ops) == 1, "A/B testing validation should have caught multiple operators"
260+
op = ops[0]
261+
args.op = op
262+
263+
print("[A/B Testing Mode Enabled]")
264+
print(f"Operator: {op}")
265+
print()
266+
267+
with gpu_lockdown(args.gpu_lockdown):
268+
try:
269+
result_a, result_b = run_ab_test(args, extra_args)
270+
# TODO: Phase 3 - Implement A/B comparison output
271+
print("\n[A/B Test Results]")
272+
print("Configuration A result:", result_a.benchmark_name if result_a else "Failed")
273+
print("Configuration B result:", result_b.benchmark_name if result_b else "Failed")
274+
except Exception as e:
275+
print(f"A/B test failed: {e}")
276+
if not args.bypass_fail:
277+
raise
278+
else:
279+
# Normal mode
280+
# Force isolation in subprocess if testing more than one op.
281+
if len(ops) >= 2:
282+
args.isolate = True
131283

132-
with gpu_lockdown(args.gpu_lockdown):
133-
for op in ops:
134-
args.op = op
135-
if args.isolate:
136-
run_in_task(op)
137-
else:
138-
_run(args, extra_args)
284+
with gpu_lockdown(args.gpu_lockdown):
285+
for op in ops:
286+
args.op = op
287+
if args.isolate:
288+
run_in_task(op)
289+
else:
290+
_run(args, extra_args)
291+
139292
tritonparse_parse(args.tritonparse)
140293

141294

tritonbench/utils/parser.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,15 +250,13 @@ def get_parser(args=None):
250250
# A/B Testing parameters
251251
parser.add_argument(
252252
"--side-a",
253-
"-a",
254253
type=str,
255254
default=None,
256255
help="Configuration A for A/B testing. Specify operator-specific arguments as a string. "
257256
"Example: '--side-a \"--max-autotune --dynamic\"'",
258257
)
259258
parser.add_argument(
260259
"--side-b",
261-
"-b",
262260
type=str,
263261
default=None,
264262
help="Configuration B for A/B testing. Specify operator-specific arguments as a string. "

0 commit comments

Comments
 (0)