
Commit d0e7d2e

jianyizh authored and pytorchmergebot committed
[xpu][feature][inductor] Enable pad_mm Pass on Intel GPU (pytorch#166618)
Pull Request resolved: pytorch#166618
Approved by: https://github.com/EikanWang, https://github.com/desertfire, https://github.com/jansel
1 parent 5605fce commit d0e7d2e
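At a high level, the pad_mm pass no longer assumes CUDA: its device gate now also accepts pairs of XPU tensors, and its benchmarking setup can fall back to the xpu device. A rough standalone sketch of that gate, mirroring the torch/_inductor/fx_passes/pad_mm.py hunks further down (the _pick_benchmark_device helper is illustrative, not part of the diff):

import torch
from torch import Tensor


def check_device(a: Tensor, b: Tensor) -> bool:
    # Padding is only considered when both operands live on the same
    # supported GPU backend: CUDA as before, and now Intel XPU as well.
    return (a.is_cuda and b.is_cuda) or (a.is_xpu and b.is_xpu)


def _pick_benchmark_device() -> str:
    # Illustrative helper mirroring the device selection in _pad_mm_init below.
    if torch.cuda.is_available():
        return "cuda"
    elif torch.xpu.is_available():
        return "xpu"
    return "cpu"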

File tree

test/inductor/test_pad_mm.py
test/inductor/test_torchinductor.py
torch/_inductor/fx_passes/pad_mm.py

3 files changed: +71 -68 lines changed


test/inductor/test_pad_mm.py

Lines changed: 62 additions & 61 deletions
@@ -16,7 +16,7 @@
 from torch._inductor.utils import fresh_cache, is_big_gpu, run_and_get_code
 from torch.testing import FileCheck
 from torch.testing._internal.common_utils import skipIfRocm
-from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU_AND_TRITON


 class PadMMTest(TestCase):
@@ -38,15 +38,15 @@ class Model(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()
                 self.w = rand_strided(
-                    (K2, N), (1, K2), device="cuda", dtype=torch.float32
+                    (K2, N), (1, K2), device=GPU_TYPE, dtype=torch.float32
                 )

             def forward(self, a):
                 a1 = torch.narrow(a, 1, 0, K2)
                 return torch.mm(a1, self.w)

-        fn = Model().cuda()
-        a = rand_strided((M, K1), (K1, 1), device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = rand_strided((M, K1), (K1, 1), device=GPU_TYPE, dtype=torch.float32)
         aligned_k = get_padded_length(K2, get_alignment_size(a)) + K2
         torch._dynamo.mark_dynamic(a, 0)
         with unittest.mock.patch(
@@ -72,17 +72,17 @@ class Model(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()
                 self.w = rand_strided(
-                    (K2, N), (1, K2), device="cuda", dtype=torch.float32
+                    (K2, N), (1, K2), device=GPU_TYPE, dtype=torch.float32
                 )

             def forward(self, a, b):
                 c = torch.cat([a, b], dim=0)
                 a1 = torch.narrow(c, 1, 0, K2)
                 return torch.mm(a1, self.w)

-        fn = Model().cuda()
-        a = rand_strided((M1, K1), (K1, 1), device="cuda", dtype=torch.float32)
-        b = rand_strided((M2, K1), (K1, 1), device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = rand_strided((M1, K1), (K1, 1), device=GPU_TYPE, dtype=torch.float32)
+        b = rand_strided((M2, K1), (K1, 1), device=GPU_TYPE, dtype=torch.float32)
         torch._dynamo.mark_dynamic(a, 0)
         torch._dynamo.mark_dynamic(b, 0)
         aligned_k = get_padded_length(K2, get_alignment_size(a)) + K2
@@ -110,9 +110,9 @@ def __init__(self) -> None:
             def forward(self, a, b):
                 return torch.mm(a, b)

-        fn = Model().cuda()
-        a = rand_strided((M, K), (K, 1), device="cuda", dtype=torch.float32)
-        b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = rand_strided((M, K), (K, 1), device=GPU_TYPE, dtype=torch.float32)
+        b = rand_strided((K, N), (1, K), device=GPU_TYPE, dtype=torch.float32)
         aligned_k = get_padded_length(K, get_alignment_size(a)) + K
         torch._dynamo.mark_dynamic(b, 1)
         with unittest.mock.patch(
@@ -139,9 +139,9 @@ def __init__(self) -> None:
             def forward(self, a, b):
                 return torch.mm(a, b)

-        fn = Model().cuda()
-        a = rand_strided((M, K), (K, 1), device="cuda", dtype=torch.float32)
-        b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = rand_strided((M, K), (K, 1), device=GPU_TYPE, dtype=torch.float32)
+        b = rand_strided((K, N), (1, K), device=GPU_TYPE, dtype=torch.float32)
         # TODO: Getting the alignment right requires pattern matcher to
         # run on newly added nodes
         aligned_m = get_padded_length(M, get_alignment_size(a)) + M
@@ -168,9 +168,9 @@ def __init__(self) -> None:
             def forward(self, a, b):
                 return torch.mm(a, b)

-        fn = Model().cuda()
-        a = rand_strided((M, K), (K, 1), device="cuda", dtype=torch.float32)
-        b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = rand_strided((M, K), (K, 1), device=GPU_TYPE, dtype=torch.float32)
+        b = rand_strided((K, N), (1, K), device=GPU_TYPE, dtype=torch.float32)
         torch._dynamo.mark_dynamic(a, 0)
         torch._dynamo.mark_dynamic(a, 1)
         torch._dynamo.mark_dynamic(b, 0)
@@ -188,9 +188,9 @@ def test_zero_dim(self):
         def addmm(x, a, b):
             return torch.addmm(x, a, b)

-        x = torch.randn(100).cuda()
-        a = torch.randn(0, 10).cuda()
-        b = torch.randn(10, 100).cuda()
+        x = torch.randn(100).to(GPU_TYPE)
+        a = torch.randn(0, 10).to(GPU_TYPE)
+        b = torch.randn(10, 100).to(GPU_TYPE)
         self.assertEqual(torch.compile(addmm)(x, a, b), addmm(x, a, b))

     @inductor_config.patch(
@@ -209,9 +209,9 @@ def __init__(self) -> None:
             def forward(self, a, b):
                 return torch.bmm(a, b)

-        fn = Model().cuda()
-        a = torch.randn(B, M, K, device="cuda", dtype=torch.float32)
-        b = torch.randn(B, K, N, device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = torch.randn(B, M, K, device=GPU_TYPE, dtype=torch.float32)
+        b = torch.randn(B, K, N, device=GPU_TYPE, dtype=torch.float32)
         aligned_k = get_padded_length(K, get_alignment_size(a)) + K
         torch._dynamo.mark_dynamic(a, 0)
         torch._dynamo.mark_dynamic(b, 0)
@@ -240,9 +240,9 @@ def __init__(self) -> None:
             def forward(self, a, b):
                 return torch.bmm(a, b)

-        fn = Model().cuda()
-        a = torch.randn(B, M, K, device="cuda", dtype=torch.float32)
-        b = torch.randn(B, K, N, device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = torch.randn(B, M, K, device=GPU_TYPE, dtype=torch.float32)
+        b = torch.randn(B, K, N, device=GPU_TYPE, dtype=torch.float32)
         aligned_n = get_padded_length(N, get_alignment_size(b)) + N
         torch._dynamo.mark_dynamic(a, 2)
         torch._dynamo.mark_dynamic(b, 1)
@@ -271,9 +271,9 @@ def __init__(self) -> None:
             def forward(self, a, b):
                 return torch.bmm(a, b)

-        fn = Model().cuda()
-        a = torch.randn(B, M, K, device="cuda", dtype=torch.float32)
-        b = torch.randn(B, K, N, device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = torch.randn(B, M, K, device=GPU_TYPE, dtype=torch.float32)
+        b = torch.randn(B, K, N, device=GPU_TYPE, dtype=torch.float32)
         aligned_n = get_padded_length(N, get_alignment_size(b)) + N
         torch._dynamo.mark_dynamic(a, 0)
         torch._dynamo.mark_dynamic(a, 1)
@@ -302,10 +302,10 @@ def __init__(self) -> None:
             def forward(self, a, b, c):
                 return torch.addmm(a, b, c)

-        fn = Model().cuda()
-        a = torch.randn(M, N, device="cuda", dtype=torch.float32)
-        b = torch.randn(M, K, device="cuda", dtype=torch.float32)
-        c = torch.randn(K, N, device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = torch.randn(M, N, device=GPU_TYPE, dtype=torch.float32)
+        b = torch.randn(M, K, device=GPU_TYPE, dtype=torch.float32)
+        c = torch.randn(K, N, device=GPU_TYPE, dtype=torch.float32)
         aligned_k = get_padded_length(K, get_alignment_size(b)) + K
         torch._dynamo.mark_dynamic(a, 0)
         torch._dynamo.mark_dynamic(b, 0)
@@ -333,10 +333,10 @@ def __init__(self) -> None:
             def forward(self, a, b, c):
                 return torch.addmm(a, b, c)

-        fn = Model().cuda()
-        a = torch.randn(M, N, device="cuda", dtype=torch.float32)
-        b = torch.randn(M, K, device="cuda", dtype=torch.float32)
-        c = torch.randn(K, N, device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = torch.randn(M, N, device=GPU_TYPE, dtype=torch.float32)
+        b = torch.randn(M, K, device=GPU_TYPE, dtype=torch.float32)
+        c = torch.randn(K, N, device=GPU_TYPE, dtype=torch.float32)
         torch._dynamo.mark_dynamic(a, 0)
         torch._dynamo.mark_dynamic(a, 1)
         torch._dynamo.mark_dynamic(b, 0)
@@ -357,7 +357,7 @@ def test_pad_single_cat(self):
         def foo(x, y):
             return x @ y

-        inps = [torch.rand([5, 5], device="cuda") for _ in range(2)]
+        inps = [torch.rand([5, 5], device=GPU_TYPE) for _ in range(2)]
         out = foo(*inps)
         self.assertEqual(out, inps[0] @ inps[1])

@@ -371,19 +371,19 @@ def foo(input, x, y):
         for a in [1, 4]:
             for b in [1, 6]:
                 inps = (
-                    torch.rand([a, b], device="cuda"),
-                    torch.rand([4, 5], device="cuda"),
-                    torch.rand([5, 6], device="cuda"),
+                    torch.rand([a, b], device=GPU_TYPE),
+                    torch.rand([4, 5], device=GPU_TYPE),
+                    torch.rand([5, 6], device=GPU_TYPE),
                 )
                 out = foo(*inps)
                 out_eager = torch.ops.aten.addmm(*inps)
                 self.assertEqual(out, out_eager)

         for a in [1, 6]:
             inps = (
-                torch.rand([a], device="cuda"),
-                torch.rand([4, 5], device="cuda"),
-                torch.rand([5, 6], device="cuda"),
+                torch.rand([a], device=GPU_TYPE),
+                torch.rand([4, 5], device=GPU_TYPE),
+                torch.rand([5, 6], device=GPU_TYPE),
             )
             out = foo(*inps)
             out_eager = torch.ops.aten.addmm(*inps)
@@ -395,8 +395,8 @@ def test_pad_batch(self):
         n = 9
         k = 11
         batch_size = 3
-        mat1 = torch.ones((batch_size, m, k), device="cuda", dtype=torch.float16)
-        mat2 = torch.ones((batch_size, k, n), device="cuda", dtype=torch.float16)
+        mat1 = torch.ones((batch_size, m, k), device=GPU_TYPE, dtype=torch.float16)
+        mat2 = torch.ones((batch_size, k, n), device=GPU_TYPE, dtype=torch.float16)
         expected_alignment = get_alignment_size(mat1)

         assert expected_alignment == 8, "Alignment for float16 should be 8"
@@ -413,7 +413,7 @@ def bmm(mat1, mat2):
         # in call code, expect to see a single pad per input, and then we should see padded allocation for output
         FileCheck().check("del async_compile").check_count(
             ".run(", 2, exactly=True
-        ).check("empty_strided_cuda((3, 8, 16)").run(code)
+        ).check(f"empty_strided_{GPU_TYPE}((3, 8, 16)").run(code)

         assert torch.allclose(res2, bmm_expected_result), (
             "BMM results are not identical"
@@ -425,7 +425,7 @@ def test_exclude_padding(self):
         def mm(a, b):
             return a @ b

-        mm(torch.rand([25, 25], device="cuda"), torch.rand([25, 25], device="cuda"))
+        mm(torch.rand([25, 25], device=GPU_TYPE), torch.rand([25, 25], device=GPU_TYPE))
         local_cache = get_pad_cache().get_local_cache()
         self.assertTrue(len(local_cache) == 2)
         FileCheck().check_count("exclude_pad:False", 2, exactly=True).run(
@@ -436,7 +436,7 @@ def mm(a, b):
         def mm(a, b):
             return (a + 1) @ b

-        mm(torch.rand([25, 25], device="cuda"), torch.rand([25, 25], device="cuda"))
+        mm(torch.rand([25, 25], device=GPU_TYPE), torch.rand([25, 25], device=GPU_TYPE))
         local_cache = get_pad_cache().get_local_cache()
         # reuse original base timing
         self.assertTrue(len(local_cache) == 3)
@@ -455,8 +455,8 @@ def test_exclude_cat_padding(self):
         def mm(inps, b):
             return torch.cat(inps) @ b

-        inp = torch.rand([2046, 2046], device="cuda")
-        inp2 = torch.rand([2046, 2046], device="cuda")
+        inp = torch.rand([2046, 2046], device=GPU_TYPE)
+        inp2 = torch.rand([2046, 2046], device=GPU_TYPE)

         inps = inp.chunk(3)
         mm(inps, inp2)
@@ -471,7 +471,8 @@ def mm(inps, b):
         )

     @unittest.skipIf(
-        not torch.cuda.is_available() or torch.cuda.get_device_capability() >= (9, 0),
+        (not torch.cuda.is_available() or torch.cuda.get_device_capability() >= (9, 0))
+        and (not torch.xpu.is_available()),
         "No perf regression on H100+ with BF16",
     )
     @skipIfRocm
@@ -483,8 +484,8 @@ def test_pad_mm_bf16(self):
         m = 2
         n = 13
         k = 15691904
-        mat1 = torch.ones((m, k), device="cuda", dtype=torch.bfloat16)
-        mat2 = torch.ones((k, n), device="cuda", dtype=torch.bfloat16)
+        mat1 = torch.ones((m, k), device=GPU_TYPE, dtype=torch.bfloat16)
+        mat2 = torch.ones((k, n), device=GPU_TYPE, dtype=torch.bfloat16)
         expected_alignment = get_alignment_size(mat1)

         assert expected_alignment == 8, "Alignment for bfloat16 should be 8"
@@ -504,7 +505,7 @@ def mm(mat1, mat2):
         # in call code, expect to see a single pad per input, and then we should see padded allocation for output
         FileCheck().check("del async_compile").check_count(
             ".run(", 2, exactly=True
-        ).check("empty_strided_cuda((8, 16)").run(code)
+        ).check(f"empty_strided_{GPU_TYPE}((8, 16)").run(code)

         assert torch.allclose(res2, mm_expected_result), "MM results are not identical"

@@ -521,8 +522,8 @@ def fn(x, y):
             return x @ y

         args = [
-            torch.randn(2**4, 2**8 - 1, device="cuda", dtype=torch.float16),
-            torch.randn(2**8 - 1, 2**4, device="cuda", dtype=torch.float16),
+            torch.randn(2**4, 2**8 - 1, device=GPU_TYPE, dtype=torch.float16),
+            torch.randn(2**8 - 1, 2**4, device=GPU_TYPE, dtype=torch.float16),
         ]

         counters.clear()
@@ -615,7 +616,7 @@ def test_masked_mha(B, H, S, D, device, dtype):
             ):
                 mha = torch.compile(mha, fullgraph=True, backend="inductor")
                 with torch.autocast(
-                    device_type="cuda", dtype=dtype, cache_enabled=False
+                    device_type=GPU_TYPE, dtype=dtype, cache_enabled=False
                 ):
                     out_vid = mha(x1, x2, attn_mask)
                     target_vid = torch.randn_like(out_vid)
@@ -624,7 +625,7 @@ def test_masked_mha(B, H, S, D, device, dtype):
                     loss = loss_vid
                     loss.backward()

-            torch.cuda.synchronize()
+            torch.accelerator.synchronize()

             # Check if any bmm operations had dtype changes
             for node_name_pre, node_name_post in zip(
@@ -642,13 +643,13 @@ def test_masked_mha(B, H, S, D, device, dtype):
             self.assertFalse(torch.any(x2.grad.isnan()).item())

         B, H, S, D = 2, 32, 549, 128
-        device = "cuda"
+        device = GPU_TYPE
         dtype = torch.bfloat16
         torch.compiler.reset()
         torch.manual_seed(42)
         test_masked_mha(B, H, S, D, device, dtype)


 if __name__ == "__main__":
-    if HAS_CUDA_AND_TRITON:
+    if HAS_GPU_AND_TRITON:
         run_tests()
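The test churn above is one mechanical pattern: hard-coded "cuda" strings and .cuda() calls become GPU_TYPE and .to(GPU_TYPE) from torch.testing._internal.inductor_utils, so the same test bodies run on CUDA or Intel XPU. A minimal sketch of the pattern, assuming a Triton-capable GPU build; the toy check below is illustrative, not one of the tests:

import torch
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU_AND_TRITON


def mm(a, b):
    return a @ b


if HAS_GPU_AND_TRITON:
    # GPU_TYPE resolves to "cuda" or "xpu", whichever backend is available.
    a = torch.randn(2, 127, device=GPU_TYPE, dtype=torch.float32)
    b = torch.randn(127, 13, device=GPU_TYPE, dtype=torch.float32)
    out = torch.compile(mm)(a, b)
    # Mirror the tests' compiled-vs-eager comparison.
    torch.testing.assert_close(out, mm(a, b))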

test/inductor/test_torchinductor.py

Lines changed: 0 additions & 2 deletions
@@ -15893,8 +15893,6 @@ def wrapper(inp, weight):
         _, code = run_and_get_code(wrapper, inp, weight)
         self.assertTrue("in_out_ptr" in code[1])

-    # TODO: Enable this case after pad_mm is enabled on XPU.
-    @expectedFailureXPU
     @torch._functorch.config.patch("donated_buffer", True)
     @torch._inductor.config.patch("force_shape_pad", True)
     def test_donated_buffer_inplace_gpt(self):

torch/_inductor/fx_passes/pad_mm.py

Lines changed: 9 additions & 5 deletions
@@ -76,7 +76,7 @@ def get_alignment_size_dtype(dtype: torch.dtype) -> int:


 def check_device(a: Tensor, b: Tensor) -> bool:
-    return a.is_cuda and b.is_cuda
+    return (a.is_cuda and b.is_cuda) or (a.is_xpu and b.is_xpu)


 def check_dtype(a: Tensor, b: Tensor) -> bool:
@@ -225,7 +225,7 @@ def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool:
         dtype is torch.bfloat16
         and K > M
         and K > N
-        and torch.cuda.get_device_capability() < (9, 0)
+        and (torch.xpu.is_available() or torch.cuda.get_device_capability() < (9, 0))
     ):  # doesn't repro on h100s:
         return True

@@ -280,7 +280,9 @@ def tensor_key(t: Tensor) -> tuple[torch.Size, tuple[int, ...], torch.dtype]:
         return (t.shape, t.stride(), t.dtype)

     tf32_key = (
-        None if mat1.dtype != torch.float32 else torch.backends.cuda.matmul.allow_tf32
+        None
+        if mat1.dtype != torch.float32
+        else torch.backends.cuda.matmul.allow_tf32 or torch.backends.mkldnn.allow_tf32
     )

     def fmt_pad(name: str) -> str | None:
@@ -381,7 +383,7 @@ def should_pad_mm_bf16(dtype: torch.dtype, M: int, N: int, K: int) -> bool:
         and K > N
         and N % 2 == 1
         and K >= large_k_threshold_to_pad
-        and torch.cuda.get_device_capability() < (9, 0)
+        and (torch.xpu.is_available() or torch.cuda.get_device_capability() < (9, 0))
     ):  # doesn't repro on h100s:
         return True
     return False
@@ -549,7 +551,7 @@ def write_pad():

     if op is torch.ops.aten.addmm:
         input_pad = None
-        if input is not None and input.is_cuda:
+        if input is not None and (input.is_cuda or input.is_xpu):
             input_pad = torch.randn_like(input)
         fns.append(
             lambda: pad_addmm(
@@ -870,6 +872,8 @@ def _pad_mm_init() -> None:
     if torch.cuda.is_available():
         # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
+    elif torch.xpu.is_available():
+        device = "xpu"
     else:
         device = "cpu"

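With check_device and the bf16 heuristics relaxed, the pass can be exercised on Intel GPU much as the tests above do on CUDA. A hedged end-to-end sketch, assuming a GPU-enabled PyTorch build; force_shape_pad and shape_padding are the knobs the tests patch, and whether padding actually fires still depends on the pass's heuristics:

import torch
import torch._inductor.config as inductor_config


def mm(a, b):
    return a @ b


# Assumes a GPU build; prefer XPU when present, otherwise fall back to CUDA.
device = "xpu" if torch.xpu.is_available() else "cuda"

with inductor_config.patch(force_shape_pad=True, shape_padding=True):
    # An odd inner dimension (127) is misaligned for fp16, so pad_mm is
    # expected to pad the matmul inputs when its heuristics allow it.
    a = torch.randn(8, 127, device=device, dtype=torch.float16)
    b = torch.randn(127, 8, device=device, dtype=torch.float16)
    out = torch.compile(mm)(a, b)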
