|
7 | 7 |
|
8 | 8 | import argparse
|
9 | 9 | import os
|
| 10 | +import shlex |
10 | 11 | import sys
|
11 |
| -from typing import List |
| 12 | +from typing import List, Tuple |
12 | 13 |
|
13 | 14 | from tritonbench.operator_loader import get_op_loader_bench_cls_by_name, is_loader_op
|
14 | 15 |
|
|
32 | 33 | usage_report_logger = lambda *args, **kwargs: None
|
33 | 34 |
|
34 | 35 |
|
def parse_ab_config(config_str: str) -> List[str]:
    """Parse an A/B configuration string into a list of argument tokens.

    The string is tokenized with shell-like rules (``shlex.split``), so
    quoted values containing spaces stay together as a single token.

    Args:
        config_str: Raw configuration string, e.g. ``'--foo 1 --name "a b"'``.
            An empty string or ``None`` yields an empty list.

    Returns:
        The list of tokens, in the order they appear in ``config_str``.

    Raises:
        ValueError: If the string cannot be tokenized (for example an
            unterminated quote). The underlying ``shlex`` error is attached
            as ``__cause__``.
    """
    if not config_str:
        return []

    # Use shlex to properly handle quoted arguments
    try:
        return shlex.split(config_str)
    except ValueError as e:
        # Chain the original error so the traceback keeps the root cause.
        raise ValueError(
            f"Invalid configuration string: {config_str}. Error: {e}"
        ) from e
| 46 | + |
| 47 | + |
def separate_global_and_op_args(config_args: List[str]) -> Tuple[List[str], List[str]]:
    """Separate global tritonbench args from operator-specific args.

    The tokens are run through the parser returned by ``get_parser()`` with
    ``parse_known_args``. Whatever that parser leaves over is returned as the
    operator-specific list. The global list is rebuilt by scanning the input
    for ``--option`` tokens whose name (dashes mapped to underscores) is an
    attribute of the parsed namespace.

    ``--side-a`` / ``--side-b`` are never reported as global args.

    Args:
        config_args: Tokenized A/B configuration.

    Returns:
        ``(global_args, op_args)``. If the parser rejects the input
        (``SystemExit``), every token is treated as operator-specific and
        ``([], config_args)`` is returned.
    """
    if not config_args:
        return [], []

    # Only the parse call can raise SystemExit, so keep the try body minimal.
    temp_parser = get_parser()
    try:
        parsed, op_args = temp_parser.parse_known_args(config_args)
    except SystemExit:
        # If parsing fails, treat all args as operator-specific
        return [], config_args

    def _is_value_token(token: str) -> bool:
        """Return True if ``token`` can be the value of the preceding option."""
        if not token.startswith("-"):
            return True
        # A leading dash normally marks another option, but negative numbers
        # such as "-1" or "-0.5" are legitimate option values.
        try:
            float(token)
        except ValueError:
            return False
        return True

    excluded = ("side_a", "side_b")
    global_arg_list: List[str] = []
    total = len(config_args)
    i = 0
    while i < total:
        arg = config_args[i]
        if arg.startswith("--"):
            # Support both "--name value" and "--name=value" spellings; only
            # the part before "=" is the option name.
            name, inline_sep, _ = arg[2:].partition("=")
            dest = name.replace("-", "_")
            if hasattr(parsed, dest) and dest not in excluded:
                global_arg_list.append(arg)
                # With the "=" form the value is already part of the token.
                if (
                    not inline_sep
                    and i + 1 < total
                    and _is_value_token(config_args[i + 1])
                ):
                    i += 1
                    global_arg_list.append(config_args[i])
        i += 1

    return global_arg_list, op_args
| 83 | + |
| 84 | + |
def update_args_with_global(base_args: argparse.Namespace, global_args: List[str]) -> argparse.Namespace:
    """Return a copy of ``base_args`` with A/B global arguments applied.

    ``base_args`` itself is never modified.

    The global tokens are parsed directly onto a copy of ``base_args``.
    argparse only fills in a default for attributes that are missing from the
    namespace it is given, so only options that actually appear in
    ``global_args`` change. Values the user set on the main command line are
    no longer overwritten by parser defaults.

    Args:
        base_args: Namespace parsed from the main command line.
        global_args: Global option tokens taken from one A/B configuration.

    Returns:
        A new namespace. ``side_a`` / ``side_b`` always keep the values from
        ``base_args``. If the parser rejects ``global_args``, an unmodified
        copy of ``base_args`` is returned.
    """
    base_values = vars(base_args)
    if not global_args:
        return argparse.Namespace(**base_values)

    # Parse onto a copy so a failed parse cannot leave a half-updated result.
    updated_args = argparse.Namespace(**base_values)
    temp_parser = get_parser()
    try:
        temp_parser.parse_known_args(global_args, namespace=updated_args)
    except SystemExit:
        # Keep the original args, but say so: silently ignoring the config
        # would make both sides of the A/B test run the same settings.
        print(
            f"[A/B Test] Warning: could not parse global args "
            f"{' '.join(global_args)}; keeping base arguments",
            file=sys.stderr,
        )
        return argparse.Namespace(**base_values)

    # The A/B side definitions must never be altered by a side's own config.
    for key in ("side_a", "side_b"):
        if key in base_values:
            setattr(updated_args, key, base_values[key])

    return updated_args
| 108 | + |
| 109 | + |
def run_ab_test(base_args: argparse.Namespace, base_extra_args: List[str]) -> Tuple[BenchmarkOperatorResult, BenchmarkOperatorResult]:
    """Run the benchmark once per A/B side and return both results.

    Each side's configuration string (``base_args.side_a`` /
    ``base_args.side_b``) is tokenized and split into global and
    operator-specific arguments. Global arguments are applied to a copy of
    ``base_args``; operator arguments are appended to ``base_extra_args``.
    Side A runs first, then side B.

    Returns:
        ``(result_a, result_b)`` as returned by ``_run`` for each side.
    """
    rule = "=" * 60

    def _join(tokens: List[str]) -> str:
        return " ".join(tokens)

    # Tokenize both sides up front and echo them.
    tokens_a = parse_ab_config(base_args.side_a)
    tokens_b = parse_ab_config(base_args.side_b)
    print(f"[A/B Test] Configuration A: {_join(tokens_a)}")
    print(f"[A/B Test] Configuration B: {_join(tokens_b)}")

    # Split each side into global and operator-specific tokens.
    globals_a, ops_a = separate_global_and_op_args(tokens_a)
    globals_b, ops_b = separate_global_and_op_args(tokens_b)

    for label, shared, op_only in (("A", globals_a, ops_a), ("B", globals_b, ops_b)):
        if shared:
            print(f"[A/B Test] Global args {label}: {_join(shared)}")
        if op_only:
            print(f"[A/B Test] Operator args {label}: {_join(op_only)}")
    print()

    # Build both namespaces and extra-arg lists before running anything.
    namespace_a = update_args_with_global(base_args, globals_a)
    namespace_b = update_args_with_global(base_args, globals_b)
    extras_a = base_extra_args + ops_a
    extras_b = base_extra_args + ops_b

    def _execute(label, lead, tokens, shared, op_only, namespace, extras):
        # Print the banner for one side, then run it.
        print(lead + rule)
        print(f"Running Configuration {label}: {_join(tokens)}")
        if shared:
            print(f" Global args: {_join(shared)}")
        if op_only:
            print(f" Operator args: {_join(op_only)}")
        print(rule)
        return _run(namespace, extras)

    result_a = _execute("A", "", tokens_a, globals_a, ops_a, namespace_a, extras_a)
    result_b = _execute("B", "\n", tokens_b, globals_b, ops_b, namespace_b, extras_b)

    return result_a, result_b
| 161 | + |
| 162 | + |
35 | 163 | def _run(args: argparse.Namespace, extra_args: List[str]) -> BenchmarkOperatorResult:
|
36 | 164 | if is_loader_op(args.op):
|
37 | 165 | Opbench = get_op_loader_bench_cls_by_name(args.op)
|
@@ -125,17 +253,42 @@ def run(args: List[str] = []):
|
125 | 253 | )
|
126 | 254 | return
|
127 | 255 |
|
128 |
| - # Force isolation in subprocess if testing more than one op. |
129 |
| - if len(ops) >= 2: |
130 |
| - args.isolate = True |
| 256 | + # Check if A/B testing mode is enabled |
| 257 | + if args.side_a is not None and args.side_b is not None: |
| 258 | + # A/B testing mode - only support single operator |
| 259 | + assert len(ops) == 1, "A/B testing validation should have caught multiple operators" |
| 260 | + op = ops[0] |
| 261 | + args.op = op |
| 262 | + |
| 263 | + print("[A/B Testing Mode Enabled]") |
| 264 | + print(f"Operator: {op}") |
| 265 | + print() |
| 266 | + |
| 267 | + with gpu_lockdown(args.gpu_lockdown): |
| 268 | + try: |
| 269 | + result_a, result_b = run_ab_test(args, extra_args) |
| 270 | + # TODO: Phase 3 - Implement A/B comparison output |
| 271 | + print("\n[A/B Test Results]") |
| 272 | + print("Configuration A result:", result_a.benchmark_name if result_a else "Failed") |
| 273 | + print("Configuration B result:", result_b.benchmark_name if result_b else "Failed") |
| 274 | + except Exception as e: |
| 275 | + print(f"A/B test failed: {e}") |
| 276 | + if not args.bypass_fail: |
| 277 | + raise |
| 278 | + else: |
| 279 | + # Normal mode |
| 280 | + # Force isolation in subprocess if testing more than one op. |
| 281 | + if len(ops) >= 2: |
| 282 | + args.isolate = True |
131 | 283 |
|
132 |
| - with gpu_lockdown(args.gpu_lockdown): |
133 |
| - for op in ops: |
134 |
| - args.op = op |
135 |
| - if args.isolate: |
136 |
| - run_in_task(op) |
137 |
| - else: |
138 |
| - _run(args, extra_args) |
| 284 | + with gpu_lockdown(args.gpu_lockdown): |
| 285 | + for op in ops: |
| 286 | + args.op = op |
| 287 | + if args.isolate: |
| 288 | + run_in_task(op) |
| 289 | + else: |
| 290 | + _run(args, extra_args) |
| 291 | + |
139 | 292 | tritonparse_parse(args.tritonparse)
|
140 | 293 |
|
141 | 294 |
|
|
0 commit comments