
Commit 7e51233

Fix test_linalg and test_torch issues with bf32_on_and_off updates (#1884)
Fix test_torch and test_linalg issues introduced by pytorch/pytorch@f4d8bc4#diff-7e17421f32124016eb8de04dc2f445da5786a28355e1addc72b305466f590180.

Co-authored-by: Zhong Ruijie <[email protected]>
1 parent 6e44729 commit 7e51233
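For context, the upstream PyTorch change renamed the oneDNN/MKLDNN test decorator bf32_on_and_off to reduced_f32_on_and_off in torch.testing._internal.common_mkldnn, which is why the old imports and decorators in these XPU test files stopped working. Below is a minimal, hypothetical before/after sketch of that migration; it uses only names that appear in the diffs further down, and by analogy with tf32_on_and_off it assumes the decorator reruns the test with the reduced-precision fp32 paths on and off, loosening the tolerance by the given amount when they are on.

# Sketch of the decorator migration this commit applies.
# Assumes a PyTorch build where torch.testing._internal.common_mkldnn
# already exposes reduced_f32_on_and_off (the upstream rename).
from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off

# Before the upstream change the test was decorated with:
#     from torch.testing._internal.common_mkldnn import bf32_on_and_off
#     @bf32_on_and_off(0.05)
# Afterwards a single decorator covers the reduced-precision fp32 paths.
@reduced_f32_on_and_off(0.05)
def addbmm(self, device, dtype):
    # Body elided; the real test lives in the diff below. The function is
    # later attached to the upstream class (e.g. TestLinalg.test_addbmm = addbmm),
    # so `self` ends up being a TestLinalg instance.
    ...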

2 files changed: +99 -11 lines changed

test/xpu/test_linalg_xpu.py

Lines changed: 86 additions & 3 deletions
@@ -1,5 +1,6 @@
 # Owner(s): ["module: intel"]
 
+import contextlib
 import itertools
 import math
 import unittest
@@ -13,8 +14,11 @@
     instantiate_device_type_tests,
     precisionOverride,
 )
-from torch.testing._internal.common_dtype import floating_and_complex_types_and
-from torch.testing._internal.common_mkldnn import bf32_on_and_off
+from torch.testing._internal.common_dtype import (
+    floating_and_complex_types_and,
+    floating_types_and,
+)
+from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off
 from torch.testing._internal.common_utils import (
     IS_WINDOWS,
     parametrize,
@@ -98,7 +102,7 @@ def preferred_linalg_library(self):
 @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
 @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
 @tf32_on_and_off(0.05)
-@bf32_on_and_off(0.05)
+@reduced_f32_on_and_off(0.05)
 def addbmm(self, device, dtype):
     num_batches = 2
     M, N, O = 16, 17, 18
@@ -392,6 +396,83 @@ def ck_blas_library(self):
     pass
 
 
+@precisionOverride(
+    {
+        torch.double: 1e-8,
+        torch.float: 1e-4,
+        torch.bfloat16: 5e-2,
+        torch.half: 5e-2,
+        torch.cfloat: 1e-4,
+        torch.cdouble: 1e-8,
+    }
+)
+@dtypes(*floating_types_and(torch.bfloat16, torch.half))
+@tf32_on_and_off(0.05)
+@reduced_f32_on_and_off(0.05)
+def addmm_relu_tunableop_rocm(self, device, dtype):
+    with self._tunableop_ctx():
+        torch.xpu.tunable.set_rotating_buffer_size(0)
+        torch.xpu.tunable.set_max_tuning_iterations(1)
+        self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
+
+
+def get_tunableop_untuned_filename():
+    import os
+
+    ordinal = torch.xpu.current_device()
+    untuned_filename_env = os.getenv("PYTORCH_TUNABLEOP_UNTUNED_FILENAME")
+    untuned_filename_base, _, _ = untuned_filename_env.rpartition(".")
+    untuned_filename = f"{untuned_filename_base}{ordinal}.csv"
+    return untuned_filename
+
+
+@contextlib.contextmanager
+def __tunableop_ctx(self):
+    # Initialize and then tear down TunableOp
+    import glob
+    import os
+
+    self._set_tunableop_defaults()
+    torch.xpu.tunable.enable(True)
+
+    try:
+        yield
+    finally:
+        # disables TunableOp
+        torch.xpu.tunable.enable(False)
+
+        # clean up, remove any files that were generated
+        results_filename = torch.xpu.tunable.get_filename()
+        results_filename_pattern, _, _ = results_filename.rpartition(".")
+        untuned_filename = get_tunableop_untuned_filename()
+        untuned_filename_pattern, _, _ = untuned_filename.rpartition(".")
+        patterns = [
+            f"{results_filename_pattern[:-1]}*.csv",
+            f"{untuned_filename_pattern[:-1]}*.csv",
+        ]
+        files = [f for pattern in patterns for f in glob.glob(pattern)]
+        for file in files:
+            try:
+                os.remove(file)
+            # NB: The file is locked on Windows
+            except (FileNotFoundError, PermissionError):
+                pass
+
+        # undo all the environment variables set
+        # loop through a list of potentially used
+        # environment variables.
+        env_list = [
+            "PYTORCH_TUNABLEOP_BLAS_LOG",
+            "PYTORCH_TUNABLEOP_NUMERICAL_CHECK",
+            "PYTORCH_TUNABLEOP_UNTUNED_FILENAME",
+        ]
+        for env in env_list:
+            try:
+                del os.environ[env]
+            except KeyError:
+                pass
+
+
 with XPUPatchForImport(False):
     from test_linalg import TestLinalg
 
@@ -410,6 +491,8 @@ def ck_blas_library(self):
 TestLinalg.test_matmul_small_brute_force_2d_Nd = matmul_small_brute_force_2d_Nd
 TestLinalg.test_matmul_small_brute_force_3d_Nd = matmul_small_brute_force_3d_Nd
 TestLinalg.test_ck_blas_library = ck_blas_library
+TestLinalg.test_addmm_relu_tunableop_rocm = addmm_relu_tunableop_rocm
+TestLinalg._tunableop_ctx = __tunableop_ctx
 
 TestLinalg._default_dtype_check_enabled = True
 instantiate_device_type_tests(TestLinalg, globals(), only_for=("xpu"), allow_xpu=True)
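The new __tunableop_ctx helper added above follows the standard contextlib.contextmanager setup/teardown idiom: everything after the yield sits in a finally block, so TunableOp is disabled and the generated CSV files and environment variables are cleaned up even when the wrapped test body raises. The stripped-down, self-contained sketch below shows the same idiom; scratch_env and its behavior are illustrative only and not part of the test file.

import contextlib
import os


@contextlib.contextmanager
def scratch_env(name, value):
    # Set an environment variable for the duration of a block and restore
    # its previous state afterwards, even if the block raises.
    old = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop(name, None)  # the variable did not exist before
        else:
            os.environ[name] = old


# The cleanup in `finally` runs whether or not the body raises, mirroring how
# _tunableop_ctx disables TunableOp and removes the files it generated.
with scratch_env("PYTORCH_TUNABLEOP_UNTUNED_FILENAME", "tunableop_untuned.csv"):
    print(os.environ["PYTORCH_TUNABLEOP_UNTUNED_FILENAME"])
print("PYTORCH_TUNABLEOP_UNTUNED_FILENAME" in os.environ)  # False again, unless it was set before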

test/xpu/test_torch_xpu.py

Lines changed: 13 additions & 8 deletions
@@ -66,7 +66,7 @@
     get_all_qint_dtypes,
     integral_types_and,
 )
-from torch.testing._internal.common_mkldnn import bf32_on_and_off
+from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off
 from torch.testing._internal.common_optimizers import (
     _get_optim_inputs_including_global_cliquey_kwargs,
     optim_db,
@@ -2996,7 +2996,7 @@ def test_cdist_cuda_backward(self, device):
                     self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001)
 
     @tf32_on_and_off(0.005)
-    @bf32_on_and_off(0.005)
+    @reduced_f32_on_and_off(0.08)
     def test_cdist_large(self, device):
         for cm in [
             "use_mm_for_euclid_dist_if_necessary",
@@ -3011,7 +3011,7 @@ def test_cdist_large(self, device):
 
     @slowTest
     @tf32_on_and_off(0.01)
-    @bf32_on_and_off(0.01)
+    @reduced_f32_on_and_off(0.08)
     def test_cdist_large_batch(self, device):
         for cm in [
             "use_mm_for_euclid_dist_if_necessary",
@@ -3025,7 +3025,7 @@ def test_cdist_large_batch(self, device):
             self.assertEqual(expected, actual)
 
     @tf32_on_and_off(0.005)
-    @bf32_on_and_off(0.005)
+    @reduced_f32_on_and_off(0.04)
     def test_cdist_non_contiguous(self, device):
         for cm in ["use_mm_for_euclid_dist", "donot_use_mm_for_euclid_dist"]:
             x = torch.randn(5, 7, device=device).mT
@@ -3053,7 +3053,7 @@ def test_cdist_non_contiguous(self, device):
             self.assertEqual(expected, actual)
 
     @tf32_on_and_off(0.005)
-    @bf32_on_and_off(0.005)
+    @reduced_f32_on_and_off(0.04)
     def test_cdist_non_contiguous_batch(self, device):
         for cm in ["use_mm_for_euclid_dist", "donot_use_mm_for_euclid_dist"]:
             x = torch.randn(4, 3, 2, 5, 7, device=device).mT
@@ -10913,7 +10913,9 @@ def test_manual_seed(self):
         )
         self.assertEqual(expected_initial_seed, actual_initial_seed, msg=msg)
         for invalid_seed in [min_int64 - 1, max_uint64 + 1]:
-            with self.assertRaisesRegex(RuntimeError, r"Overflow when unpacking long"):
+            with self.assertRaisesRegex(
+                ValueError, r"Overflow when unpacking long long"
+            ):
                 torch.manual_seed(invalid_seed)
 
         torch.set_rng_state(rng_state)
@@ -12546,9 +12548,12 @@ def test_size_stride(self) -> None:
     def test_invalid_arg_error_handling(self) -> None:
         """Tests that errors from old TH functions are propagated back"""
         for invalid_val in [-1, 2**65]:
-            self.assertRaises(RuntimeError, lambda: torch.set_num_threads(invalid_val))
             self.assertRaises(
-                RuntimeError, lambda: torch.set_num_interop_threads(invalid_val)
+                (ValueError, RuntimeError), lambda: torch.set_num_threads(invalid_val)
+            )
+            self.assertRaises(
+                (ValueError, RuntimeError),
+                lambda: torch.set_num_interop_threads(invalid_val),
             )
 
     def _get_tensor_prop(self, t):
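The set_num_threads change above leans on unittest's support for passing a tuple of exception classes to assertRaises, so the assertion passes whether the binding raises the older RuntimeError or the newer ValueError. A small self-contained illustration of that idiom follows; parse_thread_count is a made-up stand-in for torch.set_num_threads, not real API.

import unittest


def parse_thread_count(value):
    # Hypothetical stand-in: newer bindings raise ValueError on
    # overflow/negative input where older ones raised RuntimeError.
    if value < 0 or value > 2**63 - 1:
        raise ValueError("Overflow when unpacking long long")
    return value


class ThreadCountTest(unittest.TestCase):
    def test_invalid_values(self):
        for invalid_val in [-1, 2**65]:
            # A tuple of exception types accepts either error class,
            # which is exactly what the edited test does.
            self.assertRaises(
                (ValueError, RuntimeError), lambda: parse_thread_count(invalid_val)
            )


if __name__ == "__main__":
    unittest.main()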
