
Commit 03c7da0

iupaikov-amd authored and pruthvistony committed
test_decompose_mem_bound_mm.py tolerance increase for navi3x
1 parent 2c220b2 commit 03c7da0

File tree

2 files changed: +78 -14 lines changed


test/inductor/test_decompose_mem_bound_mm.py

Lines changed: 43 additions & 12 deletions
```diff
@@ -1,6 +1,7 @@
 # Owner(s): ["module: inductor"]
 
 import logging
+import unittest
 
 import torch
 import torch._inductor
@@ -11,8 +12,10 @@
 from torch.testing import FileCheck
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
+    patch_test_members,
+    is_navi3_arch,
     parametrize,
-    skipIfXpu,
+    TEST_XPU,
 )
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA
 from torch.testing._internal.triton_utils import requires_gpu
@@ -48,9 +51,10 @@ def forward(self, input1, input2):
 
 
 @requires_gpu
-@skipIfXpu(
-    msg="Intel GPU has not enabled decompose_mem_bound_mm PASS in "
-    "torch/_inductor/fx_passes/decompose_mem_bound_mm.py"
+@unittest.skipIf(
+    TEST_XPU,
+    "Intel GPU has not enabled decompose_mem_bound_mm PASS in "
+    "torch/_inductor/fx_passes/decompose_mem_bound_mm.py",
 )
 @torch._inductor.config.patch(
     post_grad_fusion_options={
@@ -59,31 +63,46 @@ def forward(self, input1, input2):
 )
 @instantiate_parametrized_tests
 class TestDecomposeMemMM(TestCase):
-    def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
+    def __init__(self, method_name='runTest', methodName='runTest'):
+        super().__init__(method_name, methodName)
+        self.atol = 1e-3
+        self.rtol = 1e-3
+
+    def setup_tolerance(self, rtol=None, atol=None):
+        if rtol is None:
+            rtol = self.rtol
+        if atol is None:
+            atol = self.rtol
+
+    def compare_dict_tensors(self, ref_dict, res_dict, rtol=None, atol=None):
+        self.setup_tolerance(rtol, atol)
         if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
             return False
         for key1 in ref_dict.keys():
             key2 = "_orig_mod." + key1
             assert key2 in res_dict, f"{key1} does not exist in traced module"
-            if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol):
+            if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=self.rtol, atol=self.atol):
                 return False
         return True
 
-    def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3):
+    def compare_pred(self, module, traced, input, rtol=None, atol=None):
+        self.setup_tolerance(rtol, atol)
         ref = module(*input)
         res = traced(*input)
-        self.assertEqual(ref, res, rtol=rtol, atol=atol)
+        self.assertEqual(ref, res, rtol=self.rtol, atol=self.atol)
 
-    def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3):
+    def compare_parameters(self, module, traced, rtol=None, atol=None):
+        self.setup_tolerance(rtol, atol)
         ref_params = dict(module.named_parameters())
         res_params = dict(traced.named_parameters())
-        self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol))
+        self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol=self.rtol, atol=self.atol))
 
-    def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3):
+    def compare_gradients(self, module, traced, rtol=None, atol=None):
+        self.setup_tolerance(rtol, atol)
         ref_grad = {key: param.grad for key, param in module.named_parameters()}
         res_grad = {key: param.grad for key, param in traced.named_parameters()}
         self.assertTrue(
-            self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol)
+            self.compare_dict_tensors(ref_grad, res_grad, rtol=self.rtol, atol=self.atol)
         )
 
     @parametrize(
@@ -190,6 +209,12 @@ def test_decompose_linear(self, m, n, k, has_bias, should_decompose):
         )
         counters.clear()
 
+    # We have to increase tolerance for navi3 because all fp16, bf16
+    # GEMMs operations have an accuracy issue caused by hardware limitation
+    @patch_test_members({
+        "atol": 2e-3 if is_navi3_arch() else 1e-3,
+        "rtol": 2e-3 if is_navi3_arch() else 1e-3
+    })
     @parametrize(
         "m,k,n, should_decompose",
         [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],
@@ -298,6 +323,12 @@ def test_decompose_mm_cpu(self, m, n, k, should_decompose):
         )
         counters.clear()
 
+    # We have to increase tolerance for navi3 because all fp16, bf16
+    # GEMMs operations have an accuracy issue caused by hardware limitation
+    @patch_test_members({
+        "atol": 3e-3 if is_navi3_arch() else 1e-3,
+        "rtol": 4e-3 if is_navi3_arch() else 1e-3
+    })
     @parametrize(
         "m,k,n, should_decompose",
         [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],
```
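Taken together, the mechanism is: `patch_test_members` swaps the listed members on the test instance for the duration of one test method, and the comparison helpers read `self.rtol`/`self.atol` back when they compare tensors, so the decorator's values are the ones that take effect. Below is a minimal, self-contained sketch of the same pattern; the `ToleranceTest` class and its tolerance values are illustrative, not code from the commit.

```python
import unittest
from functools import wraps
from typing import Any, Dict

import torch


def patch_test_members(updates: Dict[str, Any]):
    # Same shape as the helper this commit adds to common_utils.py:
    # save the members, override them for one test, restore afterwards.
    def decorator(test_func):
        @wraps(test_func)
        def wrapper(self, *args, **kwargs):
            original = {member: getattr(self, member) for member in updates}
            for member, value in updates.items():
                setattr(self, member, value)
            try:
                return test_func(self, *args, **kwargs)
            finally:
                for member, value in original.items():
                    setattr(self, member, value)
        return wrapper
    return decorator


class ToleranceTest(unittest.TestCase):
    # Defaults mirror TestDecomposeMemMM's __init__; the decorator overrides
    # them for exactly one test method.
    atol = 1e-3
    rtol = 1e-3

    @patch_test_members({"atol": 2e-3, "rtol": 2e-3})
    def test_loose_tolerance(self):
        a = torch.ones(4)
        b = a + 2.5e-3  # within the patched 2e-3/2e-3, outside 1e-3/1e-3
        self.assertTrue(torch.allclose(a, b, rtol=self.rtol, atol=self.atol))

    def test_defaults_restored(self):
        # Runs with the class defaults; the previous patch did not leak.
        self.assertEqual(self.atol, 1e-3)


if __name__ == "__main__":
    unittest.main()
```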

torch/testing/_internal/common_utils.py

Lines changed: 35 additions & 2 deletions
```diff
@@ -102,8 +102,18 @@
 has_pytest = False
 
 
-MI300_ARCH = ("gfx942",)
-
+MI300_ARCH = ("gfx940", "gfx941", "gfx942")
+NAVI_ARCH = ("gfx1030", "gfx1100", "gfx1101", "gfx1200", "gfx1201")
+NAVI3_ARCH = ("gfx1100", "gfx1101")
+NAVI4_ARCH = ("gfx1200", "gfx1201")
+
+def is_navi3_arch():
+    if torch.cuda.is_available():
+        prop = torch.cuda.get_device_properties(0)
+        gfx_arch = prop.gcnArchName.split(":")[0]
+        if gfx_arch in NAVI3_ARCH:
+            return True
+    return False
 
 def freeze_rng_state(*args, **kwargs):
     return torch.testing._utils.freeze_rng_state(*args, **kwargs)
@@ -5646,3 +5656,26 @@ def load_inline(*args, **kwargs):
         return func(*args, load_inline=load_inline, **kwargs)
 
     return wrapper
+
+
+# Decorator to patch multiple test class members for the duration of the subtest
+def patch_test_members(updates: Dict[str, Any]):
+    def decorator(test_func):
+        @wraps(test_func)
+        def wrapper(self, *args, **kwargs):
+            # Store the original values of the specified members
+            original_values = {member: getattr(self, member) for member in updates}
+
+            # Update the members before running the subtest
+            for member, value in updates.items():
+                setattr(self, member, value)
+
+            # Run the test function, allowing subtests to run
+            try:
+                return test_func(self, *args, **kwargs)
+            finally:
+                # Restore the original values of the specified members after the subtest finishes
+                for member, original_value in original_values.items():
+                    setattr(self, member, original_value)
+
+        return wrapper
+    return decorator
```
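A note on `is_navi3_arch`: on ROCm builds `torch.cuda` is backed by HIP, and the reported `gcnArchName` can carry target feature flags after the base architecture name, which is why the helper keeps only the part before the first colon. A small illustration with sample strings (the flag suffixes shown are plausible examples, not values read from a device):

```python
# Illustrative gcnArchName-style strings; only the base gfx target matters
# for the NAVI3 membership test.
NAVI3_ARCH = ("gfx1100", "gfx1101")

for name in ("gfx1100:sramecc+:xnack-", "gfx1101", "gfx942:sramecc+:xnack-"):
    base = name.split(":")[0]  # drop feature flags, keep the gfx target
    print(f"{name!r} -> {base}: navi3={base in NAVI3_ARCH}")
```

The first two strings report as navi3; the gfx942 (MI300-class) string does not.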
