
Commit 9f98e37

jianyizh authored and pytorchmergebot committed
[Intel GPU] add tf32 support for matmul on XPU (pytorch#144240)
Support XPU TF32 matmul using torch.backends.mkldnn.allow_tf32; we will discuss in the future whether matmul needs its own control API. ~~Support xpu tf32 matmul using torch.set_float32_matmul_precision. For conv, check pytorch#137570. We decided not to follow torch.backends.cuda.matmul.allow_tf32, because that API actually calls setAllowTF32CuBLAS to set matmul_precision to "high". We also avoid other related TF32 changes (e.g. in inductor) by not introducing a new API.~~

Pull Request resolved: pytorch#144240
Approved by: https://github.com/EikanWang
1 parent ff039d3 commit 9f98e37
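In practice the change wires `torch.backends.mkldnn.allow_tf32` through to the oneDNN matmul primitive on XPU. A minimal sketch of the user-facing toggle, assuming an XPU build of PyTorch with an Intel GPU available (shapes and device string are illustrative):

    import torch

    # Assumption: this torch build has XPU support and an Intel GPU is present.
    a = torch.randn(512, 512, device="xpu")
    b = torch.randn(512, 512, device="xpu")

    torch.backends.mkldnn.allow_tf32 = True   # f32 matmul may use TF32 (fpmath_mode::tf32)
    c_tf32 = a @ b

    torch.backends.mkldnn.allow_tf32 = False  # f32 matmul stays strict FP32 (fpmath_mode::strict)
    c_fp32 = a @ b

    # TF32 keeps the FP32 exponent range but truncates the mantissa to 10 bits,
    # so a small elementwise difference between the two results is expected.
    print((c_tf32 - c_fp32).abs().max())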

File tree

2 files changed: 128 additions & 2 deletions

aten/src/ATen/native/mkldnn/xpu/detail/Matmul.cpp

Lines changed: 6 additions & 1 deletion
@@ -194,7 +194,12 @@ sycl::event matmul(
   pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

   if (m1_dt == dnnl::memory::data_type::f32) {
-    pattr.set_fpmath_mode(dnnl::fpmath_mode::strict);
+    bool allow_tf32 = at::globalContext().allowTF32OneDNN();
+    if (allow_tf32) {
+      pattr.set_fpmath_mode(dnnl::fpmath_mode::tf32);
+    } else {
+      pattr.set_fpmath_mode(dnnl::fpmath_mode::strict);
+    }
   }

   // STEP3: create primitive
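The tests below drive this branch through the `torch.backends.mkldnn.flags(...)` context manager rather than the bare attribute, so the setting is restored on exit. A minimal sketch of that scoped form, under the same XPU assumptions as above:

    import torch

    x = torch.randn(256, 256, device="xpu")
    y = torch.randn(256, 256, device="xpu")

    # Scoped toggle: allow_tf32 reverts to its previous value when the block
    # exits, which is why the test helpers below are written as context managers.
    with torch.backends.mkldnn.flags(
        enabled=torch.backends.mkldnn.enabled,
        deterministic=torch.backends.mkldnn.deterministic,
        allow_tf32=True,
    ):
        z = x @ y  # f32 matmul eligible for fpmath_mode::tf32 on the oneDNN/XPU path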

test/xpu/test_gemm.py

Lines changed: 122 additions & 1 deletion
@@ -1,5 +1,8 @@
 # Owner(s): ["module: intel"]

+import contextlib
+import functools
+import inspect
 import itertools
 import math
 import random
@@ -23,6 +26,102 @@
 )


+@contextlib.contextmanager
+def tf32_off():
+    enabled = torch.backends.mkldnn.enabled
+    deterministic = torch.backends.mkldnn.deterministic
+    with torch.backends.mkldnn.flags(
+        enabled=enabled, deterministic=deterministic, allow_tf32=False
+    ):
+        yield
+
+
+@contextlib.contextmanager
+def tf32_on(self, tf32_precision=1e-5):
+    enabled = torch.backends.mkldnn.enabled
+    deterministic = torch.backends.mkldnn.deterministic
+    old_precision = self.precision
+    try:
+        self.precision = tf32_precision
+        with torch.backends.mkldnn.flags(
+            enabled=enabled, deterministic=deterministic, allow_tf32=True
+        ):
+            yield
+    finally:
+        self.precision = old_precision
+
+
+# This is a wrapper that runs a test twice, once with allow_tf32=True and
+# once with allow_tf32=False. When running with allow_tf32=True, it will use
+# reduced precision as specified by the argument. For example:
+#     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+#     @tf32_on_and_off(0.005)
+#     def test_matmul(self, device, dtype):
+#         a = ...; b = ...;
+#         c = torch.matmul(a, b)
+#         self.assertEqual(c, expected)
+# In the above example, when testing torch.float32, the matmul will run both
+# with TF32 mode on and with TF32 mode off; in TF32 mode, assertEqual will use
+# reduced precision to check values.
+#
+# This decorator can be used for functions with or without device/dtype, such as:
+#     @tf32_on_and_off(0.005)
+#     def test_my_op(self)
+#     @tf32_on_and_off(0.005)
+#     def test_my_op(self, device)
+#     @tf32_on_and_off(0.005)
+#     def test_my_op(self, device, dtype)
+#     @tf32_on_and_off(0.005)
+#     def test_my_op(self, dtype)
def tf32_on_and_off(tf32_precision=1e-5):
+    def with_tf32_disabled(self, function_call):
+        with tf32_off():
+            function_call()
+
+    def with_tf32_enabled(self, function_call):
+        with tf32_on(self, tf32_precision):
+            function_call()
+
+    def wrapper(f):
+        params = inspect.signature(f).parameters
+        arg_names = tuple(params.keys())
+
+        @functools.wraps(f)
+        def wrapped(*args, **kwargs):
+            kwargs.update(zip(arg_names, args))
+            cond = True
+            if "device" in kwargs:
+                cond = cond and (torch.device(kwargs["device"]).type == "xpu")
+            if "dtype" in kwargs:
+                cond = cond and (
+                    kwargs["dtype"] in {torch.float32}
+                )  # TODO: add complex64
+            if cond:
+                with_tf32_disabled(kwargs["self"], lambda: f(**kwargs))
+                with_tf32_enabled(kwargs["self"], lambda: f(**kwargs))
+            else:
+                f(**kwargs)
+
+        return wrapped
+
+    return wrapper
+
+
+# This is a wrapper that runs a test with TF32 turned off. It is designed
+# for tests that use matmul or convolutions but whose purpose is not to test
+# matmul or convolutions. Disabling TF32 forces torch.float tensors to always
+# be computed at full precision.
+def with_tf32_off(f):
+    @functools.wraps(f)
+    def wrapped(*args, **kwargs):
+        with tf32_off():
+            return f(*args, **kwargs)
+
+    return wrapped
+
+
 class TestBasicGEMM(TestCase):
     def _test_addmm_addmv(
         self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None
@@ -133,11 +232,13 @@ def maybe_transpose(cond, m):

     @precisionOverride({torch.float: 1e-4, torch.double: 1e-6, torch.half: 1e-1})
     @dtypes(torch.float32, torch.half, torch.double)
+    @tf32_on_and_off(0.05)
     def test_addmm(self, device, dtype):
         self._test_addmm_impl(torch.addmm, None, device, dtype)

     @precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4})
     @dtypes(torch.bfloat16, torch.half, torch.float, torch.double)
+    @tf32_on_and_off(0.005)
     def test_addmv(self, device, dtype):
         # have to use torch.randn(...).to(bfloat16) instead of
         # torch.randn(..., dtype=bfloat16). randn does not support
@@ -185,6 +286,7 @@ def test_addmv(self, device, dtype):
         torch.float32,
         torch.float64,
     )
+    @tf32_on_and_off(0.05)
     def test_mm(self, device, dtype):
         def _test_mm(n, m, p, dtype, genf):
             # helper function
@@ -287,6 +389,7 @@ def genf_Half(x, y):

     @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
     @dtypes(torch.float32, torch.bfloat16, torch.half, torch.float64)
+    @tf32_on_and_off(0.05)
     def test_bmm(self, device, dtype):
         batch_sizes = [1, 10]
         M, N, O = 23, 15, 12
@@ -403,6 +506,7 @@ def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor):

     @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
     @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half)
+    @tf32_on_and_off(0.005)
     def test_addbmm(self, device, dtype):
         num_batches = 2
         M, N, O = 16, 17, 18
@@ -506,6 +610,7 @@ def generate_tensor():

     @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5, torch.float64: 1e-6})
     @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half)
+    @tf32_on_and_off(0.01)
     def test_baddbmm(self, device, dtype):
         num_batches = 10
         M, N, O = 12, 8, 50
@@ -568,6 +673,7 @@ def generate_tensor():
         for b1, b2, ref, out_tensor in generate_tensor():
             self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor)

+    @tf32_on_and_off(0.05)
     def test_tensordot(self, device):
         a = torch.arange(60.0, device=device).reshape(3, 4, 5)
         b = torch.arange(24.0, device=device).reshape(4, 3, 2)
@@ -604,6 +710,7 @@ def test_tensordot(self, device):

     @dtypes(torch.float, torch.double)
     @precisionOverride({torch.float32: 1e-4})
+    @tf32_on_and_off(0.005)
     def test_1_sized_with_0_strided(self, device, dtype):
         a = make_tensor((8, 1, 64), dtype=dtype, device=device)
         a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1])
@@ -646,6 +753,7 @@ def _select_broadcastable_dims(self, dims_full=None):
             dims_small = [ds] + dims_small
         return (dims_small, dims_large, dims_full)

+    @tf32_on_and_off(0.005)
     def test_broadcast_fused_matmul(self, device):
         fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"]

@@ -692,6 +800,7 @@ def dims_full_for_fn():
         self.assertEqual(r0, r1)

     @dtypes(torch.float32, torch.float64)
+    @tf32_on_and_off(0.005)
     def test_strided_mm_bmm(self, device, dtype):
         # Tests strided view case with stride smaller than corresponding dimension size
         x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype, device=device)
@@ -706,6 +815,7 @@ def test_strided_mm_bmm(self, device, dtype):
         torch_fn = lambda x: torch.mm(x, x)  # noqa: E731
         self.compare_with_numpy(torch_fn, np_fn, sx[0])

+    @tf32_on_and_off(0.005)
     def test_mm_empty_inputs_mixed_dtype_errors(self, device):
         a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device)
         b = torch.randn(10, 20, dtype=torch.float32, device=device)
@@ -714,6 +824,7 @@ def test_mm_empty_inputs_mixed_dtype_errors(self, device):
         ):
             torch.mm(a, b)

+    @tf32_on_and_off(0.005)
     def test_matmul_45724(self, device):
         # https://github.com/pytorch/pytorch/issues/45724
         a = torch.rand(65537, 22, 64, device=device, dtype=torch.half)
@@ -731,6 +842,7 @@ def test_matmul_45724(self, device):
         torch.float32,
         torch.float64,
     )
+    @tf32_on_and_off(0.005)
     def test_baddbmm_input_dtypes_compatibility(self, device, dtype):
         batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
         batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
@@ -745,6 +857,7 @@ def test_baddbmm_input_dtypes_compatibility(self, device, dtype):
         self.assertEqual(out, y_ref)

     @dtypes(torch.float)
+    @tf32_on_and_off(0.005)
     def test_baddbmm_nan_input_with_zero_beta(self, device, dtype):
         for shape in [[3, 2, 2], [2, 20, 20]]:
             mat1, mat2 = (
@@ -767,6 +880,7 @@ def test_baddbmm_nan_input_with_zero_beta(self, device, dtype):

     @precisionOverride({torch.double: 1e-6})
     @dtypes(torch.float, torch.double)
+    @tf32_on_and_off(0.005)
     def test_addmm_sizes(self, device, dtype):
         for m in [0, 1, 25]:
             for n in [0, 1, 10]:
@@ -798,6 +912,7 @@ def test_addmm_sizes(self, device, dtype):
         }
     )
     @dtypes(torch.double, torch.float32, torch.bfloat16, torch.half)
+    @tf32_on_and_off(0.05)
     def test_addmm_gelu(self, device, dtype):
         self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype)

@@ -812,10 +927,12 @@ def test_addmm_gelu(self, device, dtype):
         }
     )
     @dtypes(torch.double, torch.float32, torch.bfloat16, torch.half)
+    @tf32_on_and_off(0.05)
     def test_addmm_relu(self, device, dtype):
         self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)

-    @dtypes(torch.float, torch.bfloat16, torch.half, torch.double)
+    @dtypes(torch.float, torch.bfloat16, torch.half)
+    @tf32_on_and_off(0.005)
     def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype):
         # tests (o, s)*(s). o is output size, s is summed size.
         o = 5
@@ -859,6 +976,7 @@ def _test(row_major, incx, incy, lda_tail):
         }
     )
     @dtypes(torch.double, torch.bfloat16, torch.half, torch.float32)
+    @tf32_on_and_off(0.005)
     def test_corner_cases_of_cublasltmatmul(self, device, dtype):
         # common case
         M = torch.randn(128, device=device).to(dtype)
@@ -998,6 +1116,7 @@ def call_torch_fn(*args, **kwargs):
             torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,), test_out=True)
         )

+    @tf32_on_and_off(0.005)
     def test_large_bmm_backward(self, device):
         A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT
         B = torch.randn([1, 1024, 65536], device=device, requires_grad=True)
@@ -1006,6 +1125,7 @@ def test_large_bmm_backward(self, device):
         # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
         (A @ B).backward(G)

+    @tf32_on_and_off(0.005)
     def test_large_bmm_mm_backward(self, device):
         A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT
         B = torch.randn([1024, 65536], device=device, requires_grad=True)
@@ -1104,6 +1224,7 @@ def test_matmul_small_brute_force_3d_Nd(self, device, dtype):
         self.check_single_matmul(x, y)

     @dtypes(torch.float)
+    @tf32_on_and_off(0.005)
     def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
         a = torch.empty(
             (256, 512), device=device, dtype=dtype, requires_grad=True
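The `tf32_on_and_off` decorator is documented by the comment block in the diff above; note how `wrapped` first folds positional arguments into kwargs (`kwargs.update(zip(arg_names, args))`) so it can detect `device` and `dtype` regardless of how the test framework passes them. The `with_tf32_off` wrapper has no inline example; a hypothetical usage sketch, assuming the imports and TestCase machinery of test_gemm.py (the test name and body are illustrative, not part of the patch):

    class TestShapePlumbing(TestCase):
        @with_tf32_off
        def test_matmul_output_shape(self, device="xpu"):
            # matmul is incidental here, so TF32 is forced off to keep float32
            # results at full precision regardless of the global setting.
            x = torch.randn(64, 64, device=device)
            self.assertEqual((x @ x).shape, torch.Size([64, 64]))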
