# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

+import json
import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple

import torch

+
try:
-    import triton.testing
+    if torch.cuda.is_available():
+        import triton.testing

-    TRITON_AVAILABLE = True
+        TRITON_AVAILABLE = True
+    else:
+        TRITON_AVAILABLE = False
except ImportError:
    TRITON_AVAILABLE = False

-from BackendBench.utils import serialize_args, uses_cuda_stream
+from BackendBench.utils import serialize_args, uses_cuda_stream, compute_errors


logger = logging.getLogger(__name__)

@@ -31,34 +39,71 @@ def format_exception(e, op, args, kwargs):
    return EXC_MSG.format(op=op_name, args=serialize_args(args, kwargs), exc=e)


-def allclose(a, b):
-    if isinstance(a, torch.Tensor):
-        torch.testing.assert_close(a, b, equal_nan=True, atol=1e-2, rtol=1e-2)
+def _allclose(a, b, atol=1e-2, rtol=1e-2):
+    # Use an explicit stack rather than recursion to avoid recursion-depth issues
+    # on deeply nested outputs.
+    stack = [(a, b)]
+
+    while len(stack) > 0:
+        curr_a, curr_b = stack.pop()
+
+        if isinstance(curr_a, torch.Tensor):
+            torch.testing.assert_close(curr_a, curr_b, equal_nan=True, atol=atol, rtol=rtol)
+        elif isinstance(curr_a, (list, tuple)):
+            assert len(curr_a) == len(curr_b)
+            # Push pairs in reverse order to preserve left-to-right checking
+            stack.extend(reversed(list(zip(curr_a, curr_b))))
+        else:
+            assert curr_a == curr_b
+
+
+def allclose(a, b, atol=1e-2, rtol=1e-2):
+    try:
+        _allclose(a, b, atol=atol, rtol=rtol)
        return True
-    if isinstance(a, (list, tuple)):
-        if len(a) != len(b):
-            raise ValueError(f"Length mismatch: {len(a)} vs {len(b)}")
-        return all(allclose(x, y) for x, y in zip(a, b))
-    return a == b
+    except Exception:
+        return False


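For orientation, here is a minimal sketch of how the new `allclose` wrapper behaves on nested outputs; the tensors and values are illustrative, and it assumes `allclose` from this module is in scope:

```python
import torch

# Nested lists/tuples are walked iteratively; tensors are compared with
# torch.testing.assert_close (atol=1e-2, rtol=1e-2 by default), everything
# else with plain equality. Any mismatch is swallowed and reported as False.
ref = [torch.ones(3), (torch.zeros(2), 42)]
res = [torch.ones(3) + 1e-3, (torch.zeros(2), 42)]

assert allclose(ref, res)                                        # drift within tolerance
assert not allclose(ref, [torch.ones(3), (torch.zeros(2), 43)])  # scalar mismatch
```
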
-def eval_correctness_test(op, impl, test):
-    """Evaluate impl of op against test."""
+def eval_correctness_test(
+    op, impl, test
+) -> Tuple[bool, Optional[str], Optional[float], Optional[float]]:
+    """Evaluate impl of op against test.
+
+    Returns:
+        Tuple of (is_correct, error_message, absolute_error, relative_error)
+    """
    args, kwargs = test.args, test.kwargs
    ref = op(*args, **kwargs)
    try:
        res = impl(*args, **kwargs)
-        return allclose(ref, res)
+        is_correct = allclose(ref, res)
+
+        # Compute errors even if test passes (for verbose mode)
+        abs_error, rel_error = compute_errors(ref, res)
+
+        return is_correct, None, abs_error, rel_error
    except Exception as e:
-        logger.warning(format_exception(e, op, args, kwargs))
-        return False
+        error_msg = format_exception(e, op, args, kwargs)
+        logger.warning(error_msg)
+        return False, str(e), None, None


-def eval_correctness(op, impl, tests):
+def eval_correctness(op, impl, tests, test_data: defaultdict = defaultdict(dict)):
+    """Evaluate correctness of impl against tests."""
    correct, total = 0, 0
    for test in tests:
-        logging.debug(f"Testing {op.__name__} with args {serialize_args(test.args, test.kwargs)}")
-        if eval_correctness_test(op, impl, test):
+        args_str = serialize_args(test.args, test.kwargs)
+        logging.debug(f"Testing {op.__name__} with args {args_str}")
+        is_correct, error_msg, abs_error, rel_error = eval_correctness_test(op, impl, test)
+
+        test_data[args_str] = {
+            "correctness_score": 1 if is_correct else 0,
+            "correctness_errors": error_msg or "",
+            "absolute_error": str(abs_error) if abs_error is not None else "",
+            "relative_error": str(rel_error) if rel_error is not None else "",
+        }
+
+        if is_correct:
            correct += 1
        total += 1

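For reference, a sketch of the per-test record that `eval_correctness` writes into `test_data`; the key format and numbers are illustrative, not the exact `serialize_args` output:

```python
from collections import defaultdict

test_data = defaultdict(dict)
# One record per test, keyed by the serialized args string.
test_data["(T([16, 16], f32),)"] = {
    "correctness_score": 1,      # 1 if the test passed, 0 otherwise
    "correctness_errors": "",    # exception text when the impl raised
    "absolute_error": "0.0013",  # from compute_errors(ref, res); "" when unavailable
    "relative_error": "0.0007",
}
# eval_performance later adds "benchmark_time" and "speedup" to the same record.
```
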
@@ -83,34 +128,80 @@ def cpu_bench(fn, num_runs=100):
    return (time.perf_counter() - start) / num_runs


-def eval_performance(op, impl, tests):
+def eval_performance(op, impl, tests, test_data: defaultdict = defaultdict(dict)):
+    """Evaluate performance of impl against tests."""
    bench_fn = (
        triton.testing.do_bench if TRITON_AVAILABLE and torch.cuda.is_available() else cpu_bench
    )
    base_times = []
    test_times = []
+    args_strs = []
+
    for test in tests:
-        logging.debug(
-            f"Benchmarking {op.__name__} with args {serialize_args(test.args, test.kwargs)}"
-        )
-        base_times.append(bench_fn(lambda: op(*test.args, **test.kwargs)))
+        args_str = serialize_args(test.args, test.kwargs)
+        args_strs.append(args_str)
+        logging.debug(f"Benchmarking {op.__name__} with args {args_str}")
+        base_time = bench_fn(lambda: op(*test.args, **test.kwargs))
+        base_times.append(base_time)
+        # Fall back to the baseline time (speedup 1.0) if the impl fails or is incorrect
+        test_time = base_time
        try:
-            allclose(op(*test.args, **test.kwargs), impl(*test.args, **test.kwargs))
+            ref = op(*test.args, **test.kwargs)
+            res = impl(*test.args, **test.kwargs)
+            if not allclose(ref, res):
+                raise ValueError(f"Reference and result tensors are not close: {ref} vs {res}")
+            test_time = bench_fn(lambda: impl(*test.args, **test.kwargs))
        except Exception:
-            test_times.append(base_times[-1])
-            continue
-        test_times.append(bench_fn(lambda: impl(*test.args, **test.kwargs)))
+            pass
+        finally:
+            test_times.append(test_time)
+            test_data[args_str]["benchmark_time"] = str(test_time)
+
    speedups = torch.tensor(base_times) / torch.tensor(test_times)
+
+    # Update test_data with speedups from the tensor
+    for i, args_str in enumerate(args_strs):
+        test_data[args_str]["speedup"] = str(speedups[i].item())
+
    return speedups.log().mean().exp()


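The performance score returned above is the geometric mean of per-test speedups; a minimal sketch of that reduction on made-up timings:

```python
import torch

base_times = torch.tensor([2.0, 1.0, 4.0])  # reference op timings (illustrative)
test_times = torch.tensor([1.0, 1.0, 8.0])  # candidate impl timings

speedups = base_times / test_times           # tensor([2.0, 1.0, 0.5])
geomean = speedups.log().mean().exp()        # exp(mean(log(x))) == geometric mean
print(float(geomean))                        # 1.0 -- a 2x win and a 2x loss cancel out
```
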
def eval_one_op(op, impl, correctness_tests, performance_tests):
-    """Evaluate impl of op against correctness_tests and performance_tests."""
-    # TODO: We should have proper error reporting instead of just saying this is 0,
-    # but that should be a separate PR.
+    """Evaluate impl of op against correctness_tests and performance_tests.
+
+    Returns:
+        Tuple of (correctness_score, performance_score, test_data)
+    """
+    test_data = defaultdict(dict)
+
    if uses_cuda_stream(impl):
        logger.warning(f"Skipping {op.__name__} because it uses CUDA stream")
-        return 0.0, 1.0
-    return eval_correctness(op, impl, correctness_tests), eval_performance(
-        op, impl, performance_tests
-    )
+        for test in correctness_tests + performance_tests:
+            args_str = serialize_args(test.args, test.kwargs)
+            test_data[args_str] = {
+                "correctness_score": 0,
+                "benchmark_time": "",
+                "speedup": "",
+                "correctness_errors": "Skipped: uses CUDA stream",
+                "absolute_error": "",
+                "relative_error": "",
+            }
+        return 0, 1.0, test_data
+
+    correctness_score = eval_correctness(op, impl, correctness_tests, test_data)
+    performance_score = eval_performance(op, impl, performance_tests, test_data)
+    test_data = dict(test_data)
+    return correctness_score, performance_score, test_data
+
+
+def save_verbose_results(
+    results: List[Dict[str, Any]],
+    output_path: str = "backendbench_verbose_results.json",
+):
+    """Save verbose results to a JSON file."""
+    with open(Path(output_path), "w") as f:
+        json.dump(results, f, indent=2)
+
+    logger.info(f"Verbose results saved to {output_path}")
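
Finally, a hedged end-to-end sketch of how these pieces fit together. `_Test` and `my_relu` are hypothetical stand-ins (real test objects only need `.args` and `.kwargs`), and it assumes this module's `eval_one_op` and `save_verbose_results` are in scope:

```python
import torch


class _Test:
    """Hypothetical test wrapper; only .args and .kwargs are required."""

    def __init__(self, *args, **kwargs):
        self.args, self.kwargs = args, kwargs


def my_relu(x):
    # Candidate implementation to be scored against torch.relu
    return torch.clamp(x, min=0.0)


tests = [_Test(torch.randn(32, 32)), _Test(torch.randn(1024))]
correctness, perf, test_data = eval_one_op(torch.relu, my_relu, tests, tests)

save_verbose_results(
    [
        {
            "op": "relu",
            "correctness_score": float(correctness),
            "performance_score": float(perf),
            "tests": test_data,
        }
    ]
)  # writes backendbench_verbose_results.json by default
```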