@@ -160,43 +160,6 @@ def extract_kernels(funcs):
     raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")
 
 
-def assert_close(x, y, atol=None, rtol=None, err_msg=""):
-    import numpy as np
-    import torch
-
-    # canonicalize arguments to be tensors
-    if not isinstance(x, torch.Tensor):
-        x = torch.tensor(x)
-    if not isinstance(y, torch.Tensor):
-        y = torch.tensor(y)
-    # absolute tolerance
-    if atol is None:
-        atol = 1e-2
-    atol = atol(x.dtype) if callable(atol) else atol
-    # relative tolerance hook
-    if rtol is None:
-        rtol = 0.
-    rtol = rtol(x.dtype) if callable(rtol) else rtol
-    # we use numpy instead of pytorch
-    # as it seems more memory efficient
-    # pytorch tends to oom on large tensors
-    if isinstance(x, torch.Tensor):
-        if x.dtype == torch.bfloat16:
-            x = x.float()
-        x = x.cpu().detach().numpy()
-    if isinstance(y, torch.Tensor):
-        if y.dtype == torch.bfloat16:
-            y = y.float()
-        y = y.cpu().detach().numpy()
-    # we handle size==1 case separately as we can
-    # provide better error message there
-    if x.size > 1 or y.size > 1:
-        np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True)
-        return
-    if not np.allclose(x, y, atol=atol, rtol=rtol):
-        raise AssertionError(f"{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})")
-
-
 def perf_report(benchmarks):
     """
     Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.
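
For context on the API this docstring introduces, here is a minimal usage sketch (not part of this commit): perf_report decorates a function, and the returned object's .run method executes the sweep. The Benchmark fields below follow triton.testing as found in recent releases; the CUDA device, sweep range, and names are illustrative assumptions.

import torch
import triton
import triton.testing

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["size"],                      # argument swept along the x axis
        x_vals=[2**i for i in range(12, 18)],  # values taken by that argument
        line_arg="provider",                   # argument that selects each plotted line
        line_vals=["torch"],
        line_names=["PyTorch"],
        ylabel="ms",
        plot_name="vector-add",                # used as the plot/CSV file name
        args={},                               # extra fixed keyword arguments
    )
)
def bench_add(size, provider):
    x = torch.randn(size, device="cuda")
    y = torch.randn(size, device="cuda")
    # do_bench returns the mean time in ms (older releases return a
    # (mean, min, max) tuple, which perf_report also accepts)
    return triton.testing.do_bench(lambda: x + y)

bench_add.run(print_data=True)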
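
Since this commit deletes assert_close from triton.testing, downstream tests need an equivalent tolerance check. Below is a minimal stand-in sketched on top of torch.testing.assert_close (a standard PyTorch API): the defaults (atol=1e-2, rtol=0) and the callable-tolerance hook mirror the deleted code, but the helper itself is illustrative, not something this commit adds. Unlike the numpy path it replaces, torch.testing.assert_close requires matching shapes rather than broadcasting.

import torch

def assert_close(x, y, atol=None, rtol=None, err_msg=""):
    # Illustrative stand-in for the removed triton.testing.assert_close;
    # defaults and the callable-tolerance hook mirror the deleted code.
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x)
    if not isinstance(y, torch.Tensor):
        y = torch.tensor(y)
    atol = 1e-2 if atol is None else atol
    rtol = 0.0 if rtol is None else rtol
    atol = atol(x.dtype) if callable(atol) else atol
    rtol = rtol(x.dtype) if callable(rtol) else rtol
    torch.testing.assert_close(
        x, y,
        atol=atol, rtol=rtol,
        equal_nan=True,      # the old code passed equal_nan=True to numpy
        check_dtype=False,   # the old code compared values, not dtypes
        check_device=False,  # the old code moved everything to CPU first
        msg=err_msg or None,
    )

# e.g. assert_close(triton_out, torch_out, atol=1e-2)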