Commit 69e14ad

Unwrap tensor descriptor kernel argument into its members, and update launch code. (#4820)
Fixes issue #4289.

Signed-off-by: Tiotto, Ettore <[email protected]>

1 parent 77894ef commit 69e14ad
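
In short, the launcher stops passing a host-side tensor descriptor as one opaque object and instead passes its members (base pointer, shape, strides) as plain scalar arguments. A minimal sketch of the idea, using a hypothetical stand-in class (names here are illustrative, not the driver's API):

```python
# Sketch of the unwrapping this commit implements in the XPU launcher
# (third_party/intel/backend/driver.py); Desc2D is a hypothetical stand-in
# for triton.tools.tensor_descriptor.TensorDescriptor.
from dataclasses import dataclass
from typing import List

@dataclass
class Desc2D:
    base: int                  # device base pointer
    shape: List[int]           # [M, N]
    strides: List[int]         # [stride_m, stride_n]

def unwrap(arg):
    """Expand a descriptor into the scalars the generated C launcher parses."""
    if isinstance(arg, Desc2D):
        # Shape and strides are passed twice; see the driver comment below.
        return [arg.base, *arg.shape, *arg.strides, *arg.shape, *arg.strides]
    return [arg]

kernel_args = [Desc2D(base=0x7F00, shape=[64, 64], strides=[64, 1]), 42]
flat = [x for arg in kernel_args for x in unwrap(arg)]
assert len(flat) == (1 + 4 * 2) + 1  # base + 2*(shape+strides) per 2-D descriptor
```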

File tree: 2 files changed (+84, -20 lines)


python/test/unit/language/test_tensor_descriptor.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -1507,8 +1507,6 @@ def test_tensor_descriptor_reduce(kind, descriptor, dtype_str, num_ctas, M_BLOCK
         pytest.skip("Broken on rocm")
     if is_xpu():
         if (kind, dtype_str) in [("add", "bfloat16")]:
-            if descriptor == "host":
-                pytest.skip("FIXME: issue #4289")
             pytest.skip("FIXME: issue #3914")
 
     @triton.jit(debug=True)
@@ -1593,8 +1591,6 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
 def test_host_tensor_descriptor_load(dtype_str, num_ctas, M_BLOCK, N_BLOCK, device):
     if num_ctas == 2 and (not is_cuda() or torch.cuda.get_device_capability(0)[0] not in (9, 10)):
         pytest.xfail("CTAs is unsupported for these cards")
-    if is_xpu():
-        pytest.skip("FIXME: issue #4289")
 
     @triton.jit(debug=True)
     def kernel(out_ptr, desc, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
@@ -1658,8 +1654,6 @@ def test_host_tensor_descriptor_matmul(num_stages, num_ctas, BLOCK_M, BLOCK_N, B
 
     if is_hip() and (BLOCK_M, BLOCK_N, BLOCK_K, num_stages) == (256, 128, 32, 4):
        pytest.skip("Insufficient shared memory on HIP devices")
-    if is_xpu():
-        pytest.skip("FIXME: issue #4289")
 
     if is_interpreter():
         M, N, K = BLOCK_M, BLOCK_N, BLOCK_K
```
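
These skips gated the host-side tensor descriptor path on XPU; with the launcher fix below they are no longer needed. For context, the re-enabled tests build a descriptor on the host and pass it directly to the kernel, roughly like this (a condensed sketch modeled on the surrounding tests, not a verbatim excerpt):

```python
import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor

@triton.jit
def copy_kernel(out_ptr, desc, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
    # Load one tile through the host-created descriptor and store it out.
    block = desc.load([0, 0])
    idx = tl.arange(0, M_BLOCK)[:, None] * N_BLOCK + tl.arange(0, N_BLOCK)[None, :]
    tl.store(out_ptr + idx, block)

x = torch.randn(64, 64, device="xpu")
out = torch.empty(16, 16, device="xpu")
desc = TensorDescriptor.from_tensor(x, block_shape=[16, 16])  # host-side descriptor
copy_kernel[(1, )](out, desc, M_BLOCK=16, N_BLOCK=16)
torch.testing.assert_close(out, x[:16, :16])
```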

third_party/intel/backend/driver.py

Lines changed: 84 additions & 14 deletions
```diff
@@ -1,5 +1,6 @@
 import importlib.metadata
 import os
+import re
 import hashlib
 import shutil
 import ctypes
@@ -13,6 +14,7 @@
 from triton.runtime.cache import get_cache_manager
 from triton.backends.compiler import GPUTarget
 from triton.backends.driver import DriverBase
+from triton.tools.tensor_descriptor import TensorDescriptor
 
 # A hard-coded cache version that can be updated when we know that the cached file is invalid and
 # there are no other ways to detect that the runtime environment has changed. For example, a shared
```
```diff
@@ -370,10 +372,40 @@ def ty_to_cpp(ty):
 
 def make_launcher(constants, signature):
 
-    def _serialize_signature(sig):
+    def _expand_signature(signature):
+        output = []
+        # Expand tensor descriptor arguments into base pointer, shape, and
+        # strides
+        for sig in signature:
+            if isinstance(sig, str) and sig.startswith("tensordesc"):
+                match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig)
+                dtype = match.group(1)
+                shape = match.group(2)
+                ndim = shape.count(",") + 1
+
+                output.append("*" + dtype)
+                # Currently the host side tensor descriptors get passed in as a
+                # tensor desc, shape, and strides. We have no way to use these
+                # shape and strides when processing tensor descriptors which is
+                # why we provide our own decomposition above. Sadly this means
+                # we have to pass the shape and strides twice.
+                for _ in range(2 * ndim):
+                    output.append("i64")
+                for _ in range(ndim):
+                    output.append("i32")
+                for _ in range(ndim):
+                    output.append("i64")
+            else:
+                output.append(sig)
+
+        return output
+
+    def _flatten_signature(sig, output):
         if isinstance(sig, tuple):
-            return ','.join(map(_serialize_signature, sig))
-        return sig
+            for x in sig:
+                _flatten_signature(x, output)
+        else:
+            output.append(sig)
 
     def _extracted_type(ty):
         if isinstance(ty, tuple):
```
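
To make the expansion concrete: each `tensordesc<dtype[dims]>` entry becomes 1 + 4*ndim scalar entries. A standalone trace of the same logic (a sketch mirroring `_expand_signature`, not the driver code itself):

```python
import re

def expand_one(sig):
    # Mirrors the tensordesc branch of _expand_signature above (sketch).
    match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig)
    dtype, dims = match.group(1), match.group(2)
    ndim = dims.count(",") + 1
    # Base pointer, then shape and strides as i64 (for the descriptor itself),
    # then the frontend's decomposition: shape as i32, strides as i64.
    return ["*" + dtype] + ["i64"] * (2 * ndim) + ["i32"] * ndim + ["i64"] * ndim

print(expand_one("tensordesc<fp16[64,32]>"))
# ['*fp16', 'i64', 'i64', 'i64', 'i64', 'i32', 'i32', 'i64', 'i64']
```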
```diff
@@ -408,11 +440,16 @@ def format_of(ty):
             "uint64_t": "K",
         }[ty_to_cpp(ty)]
 
+    expand_signature = _expand_signature(signature.values())
+    signature = {i: s for i, s in enumerate(expand_signature)}
+
     args_format = ''.join([format_of(ty) for ty in signature.values()])
     format = _BASE_ARGS_FORMAT + args_format
-    signature = ','.join(map(_serialize_signature, signature.values()))
-    signature = list(filter(bool, signature.split(',')))
-    signature = {i: s for i, s in enumerate(signature)}
+
+    flat_signature = []
+    for sig in signature.values():
+        _flatten_signature(sig, flat_signature)
+    signature = {i: s for i, s in enumerate(flat_signature)}
     args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
     # Record the end of regular arguments;
     # subsequent arguments are architecture-specific descriptors.
```
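
`_flatten_signature` then linearizes tuple-typed entries so that every leaf gets its own `_arg{i}` slot and one character in the parse format. For example (a sketch using the helper exactly as defined above):

```python
def _flatten_signature(sig, output):
    # Recurse into tuples, emit leaves (same logic as the helper above).
    if isinstance(sig, tuple):
        for x in sig:
            _flatten_signature(x, output)
    else:
        output.append(sig)

flat = []
for sig in ["*fp32", ("i32", ("i64", "*fp16")), "i1"]:
    _flatten_signature(sig, flat)
assert flat == ["*fp32", "i32", "i64", "*fp16", "i1"]
```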
```diff
@@ -632,9 +669,10 @@ def format_of(ty):
       PyObject* py_kernel;
 
       {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
-      if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &py_obj_stream, &py_kernel,
-                           &kernel_metadata, &launch_metadata,
-                           &launch_enter_hook, &launch_exit_hook {args_list})) {{
+      if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ,
+                           &py_obj_stream, &py_kernel,
+                           &kernel_metadata, &launch_metadata,
+                           &launch_enter_hook, &launch_exit_hook{args_list})) {{
        return NULL;
      }}
 
```
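
Each flattened entry contributes one character to the `PyArg_ParseTuple` format string, so the unwrapped descriptor members are parsed as ordinary scalars. A sketch of that bookkeeping; only the `uint64_t` to `"K"` mapping is visible in this diff, and the `"i"`/`"L"` entries are assumptions about `format_of`:

```python
# Hypothetical format mapping; only "uint64_t" -> "K" appears in the diff above.
FORMAT_OF = {"*fp16": "K", "i32": "i", "i64": "L"}

expanded = ["*fp16"] + ["i64"] * 4 + ["i32"] * 2 + ["i64"] * 2  # one 2-D descriptor
args_format = "".join(FORMAT_OF[ty] for ty in expanded)
print(args_format)  # KLLLLiiLL -> one format char per unwrapped member
```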
```diff
@@ -703,6 +741,32 @@ def format_of(ty):
     return src
 
 
+def wrap_handle_tensor_descriptor(launcher):
+    """
+    Replace all tensor descriptors with the base ptr, shape, and strides
+    """
+
+    def inner(args):
+        meta_args = args[:len(_BASE_ARGS_FORMAT)]
+        raw_kernel_args = args[len(_BASE_ARGS_FORMAT):]
+        final_args = []
+        for arg in raw_kernel_args:
+            if isinstance(arg, TensorDescriptor):
+                # Currently the host side tensor descriptors get decomposed in
+                # the frontend to tensor desc, shape, and strides. We have no
+                # way to use these shape and strides when processing tensor
+                # descriptors which is why we provide our own decomposition
+                # above. Sadly this means we have to pass the shape and strides
+                # twice.
+                final_args.extend([arg.base, *arg.shape, *arg.strides, *arg.shape, *arg.strides])
+            else:
+                final_args.append(arg)
+
+        return launcher(meta_args + tuple(final_args))
+
+    return inner
+
+
 def serialize_args(args, constants, signature):
     import torch
     import numbers
```
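
At launch time the wrapper applies the same decomposition to the runtime values, so the flattened argument tuple lines up one-to-one with the expanded signature. A standalone mirror of `inner` (a sketch with a fake descriptor class; the real wrapper checks `isinstance(arg, TensorDescriptor)` and slices off `len(_BASE_ARGS_FORMAT)` meta arguments):

```python
class FakeDesc:
    # Stand-in for TensorDescriptor with the three members the wrapper reads.
    def __init__(self, base, shape, strides):
        self.base, self.shape, self.strides = base, shape, strides

def unwrap_args(args, n_meta):
    meta_args, kernel_args = args[:n_meta], args[n_meta:]
    final = []
    for arg in kernel_args:
        if isinstance(arg, FakeDesc):
            # base, shape, strides, then shape and strides again.
            final.extend([arg.base, *arg.shape, *arg.strides, *arg.shape, *arg.strides])
        else:
            final.append(arg)
    return meta_args + tuple(final)

args = (0, 1, 2) + (FakeDesc(0xDEAD, [64, 32], [32, 1]), 7)
assert unwrap_args(args, 3) == (0, 1, 2, 0xDEAD, 64, 32, 32, 1, 64, 32, 32, 1, 7)
```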
```diff
@@ -767,17 +831,23 @@ class XPULauncher(object):
     def __init__(self, src, metadata):
         constants = src.constants if hasattr(src, "constants") else dict()
         arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
-        self.constants = {arg_idx(idx): value for idx, value in constants.items()}
-        self.signature = {idx: value for idx, value in src.signature.items()}
-        src = make_launcher(self.constants, self.signature)
+        constants = {arg_idx(idx): value for idx, value in constants.items()}
+        signature = {idx: value for idx, value in src.signature.items()}
+        src = make_launcher(constants, signature)
         self.mod = compile_module_from_src(src=src, name="__triton_launcher")
+        has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
+
+        self.launch = wrap_handle_tensor_descriptor(self.mod.launch) if has_tensor_desc_arg else self.mod.launch
+
         # Serialize KernelArguments for SPIR-V Runner
         self.serialize_kernel_args = knobs.intel.dump_spirv_kernel_args
+        self.constants = constants
+        self.signature = signature
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args):
         if self.serialize_kernel_args:
             serialize_args(args, self.constants, self.signature)
-        self.mod.launch(args)
+        self.launch(args)
 
 
 class XPUDriver(DriverBase):
```
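
Note that the wrapper is selected once in `__init__`, so kernels without descriptor arguments keep calling `self.mod.launch` directly with no per-call overhead. The two halves of the change maintain a simple invariant: `_expand_signature` emits 1 + 4*ndim signature entries per descriptor, and `inner` supplies exactly as many runtime values. A quick consistency check (hypothetical helper, not part of the driver):

```python
def check_descriptor_invariant(ndim):
    # Signature side: "*dtype" plus 2*ndim i64, ndim i32, ndim i64.
    n_sig = 1 + 2 * ndim + ndim + ndim
    # Launch side: base, *shape, *strides, *shape, *strides.
    n_args = 1 + 4 * ndim
    assert n_sig == n_args, (n_sig, n_args)

for ndim in (1, 2, 3):
    check_descriptor_invariant(ndim)
```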
