11# Owner(s): ["module: inductor"]
22
33import logging
4+ import unittest
45
56import torch
67import torch ._inductor
1112from torch .testing import FileCheck
1213from torch .testing ._internal .common_utils import (
1314 instantiate_parametrized_tests ,
15+ patch_test_members ,
16+ is_navi3_arch ,
1417 parametrize ,
15- skipIfXpu ,
18+ TEST_XPU ,
1619)
1720from torch .testing ._internal .inductor_utils import GPU_TYPE , HAS_CUDA
1821from torch .testing ._internal .triton_utils import requires_gpu
@@ -48,9 +51,10 @@ def forward(self, input1, input2):
4851
4952
5053@requires_gpu
51- @skipIfXpu (
52- msg = "Intel GPU has not enabled decompose_mem_bound_mm PASS in "
53- "torch/_inductor/fx_passes/decompose_mem_bound_mm.py"
54+ @unittest .skipIf (
55+ TEST_XPU ,
56+ "Intel GPU has not enabled decompose_mem_bound_mm PASS in "
57+ "torch/_inductor/fx_passes/decompose_mem_bound_mm.py" ,
5458)
5559@torch ._inductor .config .patch (
5660 post_grad_fusion_options = {
@@ -59,31 +63,46 @@ def forward(self, input1, input2):
5963)
6064@instantiate_parametrized_tests
6165class TestDecomposeMemMM (TestCase ):
62- def compare_dict_tensors (self , ref_dict , res_dict , rtol = 1e-3 , atol = 1e-3 ):
66+ def __init__ (self , method_name = 'runTest' , methodName = 'runTest' ):
67+ super ().__init__ (method_name , methodName )
68+ self .atol = 1e-3
69+ self .rtol = 1e-3
70+
71+ def setup_tolerance (self , rtol = None , atol = None ):
72+ if rtol is None :
73+ rtol = self .rtol
74+ if atol is None :
75+ atol = self .rtol
76+
77+ def compare_dict_tensors (self , ref_dict , res_dict , rtol = None , atol = None ):
78+ self .setup_tolerance (rtol , atol )
6379 if len (set (ref_dict .keys ())) != len (set (res_dict .keys ())):
6480 return False
6581 for key1 in ref_dict .keys ():
6682 key2 = "_orig_mod." + key1
6783 assert key2 in res_dict , f"{ key1 } does not exist in traced module"
68- if not torch .allclose (ref_dict [key1 ], res_dict [key2 ], rtol = rtol , atol = atol ):
84+ if not torch .allclose (ref_dict [key1 ], res_dict [key2 ], rtol = self . rtol , atol = self . atol ):
6985 return False
7086 return True
7187
72- def compare_pred (self , module , traced , input , rtol = 1e-3 , atol = 1e-3 ):
88+ def compare_pred (self , module , traced , input , rtol = None , atol = None ):
89+ self .setup_tolerance (rtol , atol )
7390 ref = module (* input )
7491 res = traced (* input )
75- self .assertEqual (ref , res , rtol = rtol , atol = atol )
92+ self .assertEqual (ref , res , rtol = self . rtol , atol = self . atol )
7693
77- def compare_parameters (self , module , traced , rtol = 1e-3 , atol = 1e-3 ):
94+ def compare_parameters (self , module , traced , rtol = None , atol = None ):
95+ self .setup_tolerance (rtol , atol )
7896 ref_params = dict (module .named_parameters ())
7997 res_params = dict (traced .named_parameters ())
80- self .assertTrue (self .compare_dict_tensors (ref_params , res_params , rtol , atol ))
98+ self .assertTrue (self .compare_dict_tensors (ref_params , res_params , rtol = self . rtol , atol = self . atol ))
8199
82- def compare_gradients (self , module , traced , rtol = 1e-3 , atol = 1e-3 ):
100+ def compare_gradients (self , module , traced , rtol = None , atol = None ):
101+ self .setup_tolerance (rtol , atol )
83102 ref_grad = {key : param .grad for key , param in module .named_parameters ()}
84103 res_grad = {key : param .grad for key , param in traced .named_parameters ()}
85104 self .assertTrue (
86- self .compare_dict_tensors (ref_grad , res_grad , rtol = rtol , atol = atol )
105+ self .compare_dict_tensors (ref_grad , res_grad , rtol = self . rtol , atol = self . atol )
87106 )
88107
89108 @parametrize (
@@ -190,6 +209,12 @@ def test_decompose_linear(self, m, n, k, has_bias, should_decompose):
190209 )
191210 counters .clear ()
192211
    # We have to increase the tolerance for navi3 because all fp16 and bf16
    # GEMM operations have an accuracy issue caused by a hardware limitation.
214+ @patch_test_members ({
215+ "atol" : 2e-3 if is_navi3_arch () else 1e-3 ,
216+ "rtol" : 2e-3 if is_navi3_arch () else 1e-3
217+ })
193218 @parametrize (
194219 "m,k,n, should_decompose" ,
195220 [(20480 , 5 , 2 , True ), (20480 , 32 , 2 , False ), (2048 , 2 , 2 , False )],
@@ -298,6 +323,12 @@ def test_decompose_mm_cpu(self, m, n, k, should_decompose):
298323 )
299324 counters .clear ()
300325
    # We have to increase the tolerance for navi3 because all fp16 and bf16
    # GEMM operations have an accuracy issue caused by a hardware limitation.
328+ @patch_test_members ({
329+ "atol" : 3e-3 if is_navi3_arch () else 1e-3 ,
330+ "rtol" : 4e-3 if is_navi3_arch () else 1e-3
331+ })
301332 @parametrize (
302333 "m,k,n, should_decompose" ,
303334 [(20480 , 5 , 2 , True ), (20480 , 32 , 2 , False ), (2048 , 2 , 2 , False )],
0 commit comments