# Owner(s): ["module: intel"]

+ import contextlib
+ import functools
+ import inspect
import itertools
import math
import random
)


+ @contextlib.contextmanager
+ def tf32_off():
+     # Preserve the current mkldnn enabled/deterministic settings and force
+     # allow_tf32 off for the duration of the context.
+     enabled = torch.backends.mkldnn.enabled
+     deterministic = torch.backends.mkldnn.deterministic
+     with torch.backends.mkldnn.flags(
+         enabled=enabled, deterministic=deterministic, allow_tf32=False
+     ):
+         yield
+
+
+ @contextlib.contextmanager
+ def tf32_on(self, tf32_precision=1e-5):
+     # Temporarily allow TF32 in mkldnn and loosen the test's precision to
+     # `tf32_precision`; both settings are restored on exit.
+     enabled = torch.backends.mkldnn.enabled
+     deterministic = torch.backends.mkldnn.deterministic
+     old_precision = self.precision
+     try:
+         self.precision = tf32_precision
+         with torch.backends.mkldnn.flags(
+             enabled=enabled, deterministic=deterministic, allow_tf32=True
+         ):
+             yield
+     finally:
+         self.precision = old_precision
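+
+
+ # Illustrative sketch (not part of the helpers above; the test name is
+ # hypothetical): the two context managers can also be used directly inside a
+ # test body, e.g. to compare a matmul computed with TF32 disabled against one
+ # computed with TF32 enabled under the looser tolerance:
+ #     def test_mm_tf32_vs_fp32(self, device):
+ #         a = torch.randn(64, 64, device=device)
+ #         b = torch.randn(64, 64, device=device)
+ #         with tf32_off():
+ #             ref = torch.mm(a, b)
+ #         with tf32_on(self, tf32_precision=0.05):
+ #             self.assertEqual(torch.mm(a, b), ref)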
+
+
+ # This is a wrapper that wraps a test to run it twice, once with
+ # allow_tf32=True and once with allow_tf32=False. When running with
+ # allow_tf32=True, it uses the reduced precision specified by the
+ # argument. For example:
+ #     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+ #     @tf32_on_and_off(0.005)
+ #     def test_matmul(self, device, dtype):
+ #         a = ...; b = ...;
+ #         c = torch.matmul(a, b)
+ #         self.assertEqual(c, expected)
+ # In the above example, when testing torch.float32, the matmul runs both with
+ # TF32 mode on and with TF32 mode off, and when TF32 mode is on, assertEqual
+ # uses the reduced precision to check values.
+ #
+ # This decorator can be used for tests with or without device/dtype arguments,
+ # such as:
+ #     @tf32_on_and_off(0.005)
+ #     def test_my_op(self)
+ #     @tf32_on_and_off(0.005)
+ #     def test_my_op(self, device)
+ #     @tf32_on_and_off(0.005)
+ #     def test_my_op(self, device, dtype)
+ #     @tf32_on_and_off(0.005)
+ #     def test_my_op(self, dtype)
+ def tf32_on_and_off(tf32_precision=1e-5):
+     def with_tf32_disabled(self, function_call):
+         with tf32_off():
+             function_call()
+
+     def with_tf32_enabled(self, function_call):
+         with tf32_on(self, tf32_precision):
+             function_call()
+
+     def wrapper(f):
+         params = inspect.signature(f).parameters
+         arg_names = tuple(params.keys())
+
+         @functools.wraps(f)
+         def wrapped(*args, **kwargs):
+             kwargs.update(zip(arg_names, args))
+             cond = True
+             if "device" in kwargs:
+                 cond = cond and (torch.device(kwargs["device"]).type == "xpu")
+             if "dtype" in kwargs:
+                 cond = cond and (
+                     kwargs["dtype"] in {torch.float32}
+                 )  # TODO: add complex64
+             if cond:
+                 with_tf32_disabled(kwargs["self"], lambda: f(**kwargs))
+                 with_tf32_enabled(kwargs["self"], lambda: f(**kwargs))
+             else:
+                 f(**kwargs)
+
+         return wrapped
+
+     return wrapper
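+
+
+ # Note (illustrative sketch with assumed names, not part of this change):
+ # `wrapped` above folds positional args into kwargs via the test's signature,
+ # so tf32_on_and_off also applies cleanly to a test that takes only `self`;
+ # with neither "device" nor "dtype" present, `cond` stays True and the test
+ # body runs twice, once with TF32 off and once with TF32 on, e.g.:
+ #     @tf32_on_and_off(0.005)
+ #     def test_my_op_no_device(self):
+ #         x = torch.randn(32, 32, device="xpu")
+ #         self.assertEqual(x @ x, torch.mm(x, x))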
+
+
+ # This is a wrapper that wraps a test to run it with TF32 turned off.
+ # This wrapper is designed to be used when a test uses matmul or convolutions,
+ # but the purpose of that test is not to test matmul or convolutions.
+ # Disabling TF32 forces torch.float tensors to always be computed at full
+ # precision.
+ def with_tf32_off(f):
+     @functools.wraps(f)
+     def wrapped(*args, **kwargs):
+         with tf32_off():
+             return f(*args, **kwargs)
+
+     return wrapped
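+
+
+ # Illustrative usage sketch (hypothetical test name, not part of this change):
+ # apply `with_tf32_off` to a test whose matmul is only incidental, so float32
+ # math stays at full precision and the default tolerances remain valid:
+ #     @with_tf32_off
+ #     @dtypes(torch.float32)
+ #     def test_shape_via_matmul(self, device, dtype):
+ #         x = torch.randn(16, 16, device=device, dtype=dtype)
+ #         self.assertEqual(torch.mm(x, x).shape, torch.Size([16, 16]))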
+
+
class TestBasicGEMM(TestCase):
    def _test_addmm_addmv(
        self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None
@@ -133,11 +232,13 @@ def maybe_transpose(cond, m):

    @precisionOverride({torch.float: 1e-4, torch.double: 1e-6, torch.half: 1e-1})
    @dtypes(torch.float32, torch.half, torch.double)
+     @tf32_on_and_off(0.05)
    def test_addmm(self, device, dtype):
        self._test_addmm_impl(torch.addmm, None, device, dtype)

    @precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4})
    @dtypes(torch.bfloat16, torch.half, torch.float, torch.double)
+     @tf32_on_and_off(0.005)
    def test_addmv(self, device, dtype):
        # have to use torch.randn(...).to(bfloat16) instead of
        # torch.randn(..., dtype=bfloat16). randn does not support
@@ -185,6 +286,7 @@ def test_addmv(self, device, dtype):
        torch.float32,
        torch.float64,
    )
+     @tf32_on_and_off(0.05)
    def test_mm(self, device, dtype):
        def _test_mm(n, m, p, dtype, genf):
            # helper function
@@ -287,6 +389,7 @@ def genf_Half(x, y):

    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
    @dtypes(torch.float32, torch.bfloat16, torch.half, torch.float64)
+     @tf32_on_and_off(0.05)
    def test_bmm(self, device, dtype):
        batch_sizes = [1, 10]
        M, N, O = 23, 15, 12
@@ -403,6 +506,7 @@ def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor):

    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
    @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half)
+     @tf32_on_and_off(0.005)
    def test_addbmm(self, device, dtype):
        num_batches = 2
        M, N, O = 16, 17, 18
@@ -506,6 +610,7 @@ def generate_tensor():

    @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5, torch.float64: 1e-6})
    @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half)
+     @tf32_on_and_off(0.01)
    def test_baddbmm(self, device, dtype):
        num_batches = 10
        M, N, O = 12, 8, 50
@@ -568,6 +673,7 @@ def generate_tensor():
        for b1, b2, ref, out_tensor in generate_tensor():
            self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor)

+     @tf32_on_and_off(0.05)
    def test_tensordot(self, device):
        a = torch.arange(60.0, device=device).reshape(3, 4, 5)
        b = torch.arange(24.0, device=device).reshape(4, 3, 2)
@@ -604,6 +710,7 @@ def test_tensordot(self, device):

    @dtypes(torch.float, torch.double)
    @precisionOverride({torch.float32: 1e-4})
+     @tf32_on_and_off(0.005)
    def test_1_sized_with_0_strided(self, device, dtype):
        a = make_tensor((8, 1, 64), dtype=dtype, device=device)
        a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1])
@@ -646,6 +753,7 @@ def _select_broadcastable_dims(self, dims_full=None):
                dims_small = [ds] + dims_small
        return (dims_small, dims_large, dims_full)

+     @tf32_on_and_off(0.005)
    def test_broadcast_fused_matmul(self, device):
        fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"]

@@ -692,6 +800,7 @@ def dims_full_for_fn():
            self.assertEqual(r0, r1)

    @dtypes(torch.float32, torch.float64)
+     @tf32_on_and_off(0.005)
    def test_strided_mm_bmm(self, device, dtype):
        # Tests strided view case with stride smaller than corresponding dimension size
        x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype, device=device)
@@ -706,6 +815,7 @@ def test_strided_mm_bmm(self, device, dtype):
        torch_fn = lambda x: torch.mm(x, x)  # noqa: E731
        self.compare_with_numpy(torch_fn, np_fn, sx[0])

+     @tf32_on_and_off(0.005)
    def test_mm_empty_inputs_mixed_dtype_errors(self, device):
        a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device)
        b = torch.randn(10, 20, dtype=torch.float32, device=device)
@@ -714,6 +824,7 @@ def test_mm_empty_inputs_mixed_dtype_errors(self, device):
        ):
            torch.mm(a, b)

+     @tf32_on_and_off(0.005)
    def test_matmul_45724(self, device):
        # https://github.com/pytorch/pytorch/issues/45724
        a = torch.rand(65537, 22, 64, device=device, dtype=torch.half)
@@ -731,6 +842,7 @@ def test_matmul_45724(self, device):
        torch.float32,
        torch.float64,
    )
+     @tf32_on_and_off(0.005)
    def test_baddbmm_input_dtypes_compatibility(self, device, dtype):
        batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
        batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
@@ -745,6 +857,7 @@ def test_baddbmm_input_dtypes_compatibility(self, device, dtype):
            self.assertEqual(out, y_ref)

    @dtypes(torch.float)
+     @tf32_on_and_off(0.005)
    def test_baddbmm_nan_input_with_zero_beta(self, device, dtype):
        for shape in [[3, 2, 2], [2, 20, 20]]:
            mat1, mat2 = (
@@ -767,6 +880,7 @@ def test_baddbmm_nan_input_with_zero_beta(self, device, dtype):

    @precisionOverride({torch.double: 1e-6})
    @dtypes(torch.float, torch.double)
+     @tf32_on_and_off(0.005)
    def test_addmm_sizes(self, device, dtype):
        for m in [0, 1, 25]:
            for n in [0, 1, 10]:
@@ -798,6 +912,7 @@ def test_addmm_sizes(self, device, dtype):
        }
    )
    @dtypes(torch.double, torch.float32, torch.bfloat16, torch.half)
+     @tf32_on_and_off(0.05)
    def test_addmm_gelu(self, device, dtype):
        self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype)

@@ -812,10 +927,12 @@ def test_addmm_gelu(self, device, dtype):
        }
    )
    @dtypes(torch.double, torch.float32, torch.bfloat16, torch.half)
+     @tf32_on_and_off(0.05)
    def test_addmm_relu(self, device, dtype):
        self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)

-     @dtypes(torch.float, torch.bfloat16, torch.half, torch.double)
+     @dtypes(torch.float, torch.bfloat16, torch.half)
+     @tf32_on_and_off(0.005)
    def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype):
        # tests (o, s)*(s). o is output size, s is summed size.
        o = 5
@@ -859,6 +976,7 @@ def _test(row_major, incx, incy, lda_tail):
        }
    )
    @dtypes(torch.double, torch.bfloat16, torch.half, torch.float32)
+     @tf32_on_and_off(0.005)
    def test_corner_cases_of_cublasltmatmul(self, device, dtype):
        # common case
        M = torch.randn(128, device=device).to(dtype)
@@ -998,6 +1116,7 @@ def call_torch_fn(*args, **kwargs):
            torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,), test_out=True)
        )

+     @tf32_on_and_off(0.005)
    def test_large_bmm_backward(self, device):
        A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT
        B = torch.randn([1, 1024, 65536], device=device, requires_grad=True)
@@ -1006,6 +1125,7 @@ def test_large_bmm_backward(self, device):
        # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
        (A @ B).backward(G)

+     @tf32_on_and_off(0.005)
    def test_large_bmm_mm_backward(self, device):
        A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT
        B = torch.randn([1024, 65536], device=device, requires_grad=True)
@@ -1104,6 +1224,7 @@ def test_matmul_small_brute_force_3d_Nd(self, device, dtype):
            self.check_single_matmul(x, y)

    @dtypes(torch.float)
+     @tf32_on_and_off(0.005)
    def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
        a = torch.empty(
            (256, 512), device=device, dtype=dtype, requires_grad=True