Skip to content

Commit 17a1e56

Browse files
committed
[intel] align driver after 'dba08d4' and xfail 'test_tensor_descriptor_padding'
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 2d934e7 commit 17a1e56

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

python/test/unit/language/test_tensor_descriptor.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,9 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
383383

384384

385385
@pytest.mark.interpreter
386-
def test_tensor_descriptor_padding():
386+
def test_tensor_descriptor_padding(device):
387+
if not is_cuda():
388+
pytest.xfail("padding is unsupported")
387389

388390
@triton.jit
389391
def device_tma_load(in_ptr, out_ptr, IM, IN, YM, YN, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr,
@@ -414,7 +416,7 @@ def host_tma_load(in_desc, out_ptr, YM, YN, M_BLOCK: tl.constexpr, N_BLOCK: tl.c
414416

415417
# TMA descriptors require a global memory allocation
416418
def alloc_fn(size: int, alignment: float, stream: float):
417-
return torch.ones(size, device="cuda", dtype=torch.float32)
419+
return torch.ones(size, device=device, dtype=torch.float32)
418420

419421
triton.set_allocator(alloc_fn)
420422

@@ -423,16 +425,16 @@ def alloc_fn(size: int, alignment: float, stream: float):
423425
M_BLOCK = 32
424426
N_BLOCK = 32
425427
padding = "nan"
426-
input = torch.arange(IM * IN, device="cuda", dtype=torch.float32)
428+
input = torch.arange(IM * IN, device=device, dtype=torch.float32)
427429
input = input.reshape(IM, IN)
428-
out_device_tma = torch.zeros((OM, ON), device="cuda", dtype=torch.float32)
429-
out_host_tma = torch.zeros((OM, ON), device="cuda", dtype=torch.float32)
430+
out_device_tma = torch.zeros((OM, ON), device=device, dtype=torch.float32)
431+
out_host_tma = torch.zeros((OM, ON), device=device, dtype=torch.float32)
430432
dummy_block = [M_BLOCK, N_BLOCK]
431433
in_desc = TensorDescriptor(input, input.shape, input.stride(), dummy_block, padding=padding)
432434
grid = (triton.cdiv(OM, M_BLOCK), triton.cdiv(ON, N_BLOCK))
433435
device_tma_load[grid](input, out_device_tma, IM, IN, OM, ON, M_BLOCK, N_BLOCK, padding)
434436
host_tma_load[grid](in_desc, out_host_tma, OM, ON, M_BLOCK, N_BLOCK)
435-
expected = torch.zeros((OM, ON), device="cuda", dtype=torch.float32)
437+
expected = torch.zeros((OM, ON), device=device, dtype=torch.float32)
436438
expected[0:IN, 0:IM] = input
437439
expected[:, IN:ON] = float('nan')
438440
expected[IM:OM, :] = float('nan')

third_party/intel/backend/driver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,7 @@ def _expand_signature(signature):
391391
# we have to pass the shape and strides twice.
392392
for _ in range(2 * ndim):
393393
output.append("i64")
394+
output.append("i1")
394395
for _ in range(ndim):
395396
output.append("i32")
396397
for _ in range(ndim):
@@ -797,7 +798,7 @@ def inner(args):
797798
# descriptors which is why we provide our own decomposition
798799
# above. Sadly this means we have to pass the shape and strides
799800
# twice.
800-
final_args.extend([arg.base, *arg.shape, *arg.strides, *arg.shape, *arg.strides])
801+
final_args.extend([arg.base, *arg.shape, *arg.strides, arg.padding == "nan", *arg.shape, *arg.strides])
801802
else:
802803
final_args.append(arg)
803804

0 commit comments

Comments (0)