15
15
16
16
from tritonbench .operators import load_opbench_by_name
17
17
from tritonbench .operators_collection import list_operators_by_collection
18
+ from tritonbench .utils .ab_test import compare_ab_results , run_ab_test
18
19
from tritonbench .utils .env_utils import is_fbcode
19
20
from tritonbench .utils .gpu_utils import gpu_lockdown
20
21
from tritonbench .utils .list_operator_details import list_operator_details
23
24
24
25
from tritonbench .utils .triton_op import BenchmarkOperatorResult
25
26
from tritonbench .utils .tritonparse_utils import tritonparse_init , tritonparse_parse
26
- from tritonbench .utils .ab_test import run_ab_test , compare_ab_results
27
27
28
28
try :
29
29
if is_fbcode ():
34
34
usage_report_logger = lambda * args , ** kwargs : None
35
35
36
36
37
-
38
-
39
37
def _run (args : argparse .Namespace , extra_args : List [str ]) -> BenchmarkOperatorResult :
40
38
if is_loader_op (args .op ):
41
39
Opbench = get_op_loader_bench_cls_by_name (args .op )
@@ -132,23 +130,26 @@ def run(args: List[str] = []):
132
130
# Check if A/B testing mode is enabled
133
131
if args .side_a is not None and args .side_b is not None :
134
132
# A/B testing mode - only support single operator
135
- assert len (ops ) == 1 , "A/B testing validation should have caught multiple operators"
133
+ assert (
134
+ len (ops ) == 1
135
+ ), "A/B testing validation should have caught multiple operators"
136
136
op = ops [0 ]
137
137
args .op = op
138
-
138
+
139
139
print ("[A/B Testing Mode Enabled]" )
140
140
print (f"Operator: { op } " )
141
141
print ()
142
-
142
+
143
143
with gpu_lockdown (args .gpu_lockdown ):
144
144
try :
145
145
result_a , result_b = run_ab_test (args , extra_args , _run )
146
-
146
+
147
147
from tritonbench .utils .ab_test import parse_ab_config
148
+
148
149
config_a_args = parse_ab_config (args .side_a )
149
150
config_b_args = parse_ab_config (args .side_b )
150
151
compare_ab_results (result_a , result_b , config_a_args , config_b_args )
151
-
152
+
152
153
except Exception as e :
153
154
print (f"A/B test failed: { e } " )
154
155
if not args .bypass_fail :
@@ -166,7 +167,7 @@ def run(args: List[str] = []):
166
167
run_in_task (op )
167
168
else :
168
169
_run (args , extra_args )
169
-
170
+
170
171
tritonparse_parse (args .tritonparse )
171
172
172
173
0 commit comments