|
1 | 1 | import importlib.metadata
|
2 | 2 | import os
|
| 3 | +import re |
3 | 4 | import hashlib
|
4 | 5 | import shutil
|
5 | 6 | import ctypes
|
|
13 | 14 | from triton.runtime.cache import get_cache_manager
|
14 | 15 | from triton.backends.compiler import GPUTarget
|
15 | 16 | from triton.backends.driver import DriverBase
|
| 17 | +from triton.tools.tensor_descriptor import TensorDescriptor |
16 | 18 |
|
17 | 19 | # A hard-coded cache version that can be updated when we know that the cached file is invalid and
|
18 | 20 | # there are no other ways to detect that the runtime environment has changed. For example, a shared
|
@@ -370,10 +372,40 @@ def ty_to_cpp(ty):
|
370 | 372 |
|
371 | 373 | def make_launcher(constants, signature):
|
372 | 374 |
|
373 |
| - def _serialize_signature(sig): |
| 375 | + def _expand_signature(signature): |
| 376 | + output = [] |
| 377 | + # Expand tensor descriptor arguments into base pointer, shape, and |
| 378 | + # strides |
| 379 | + for sig in signature: |
| 380 | + if isinstance(sig, str) and sig.startswith("tensordesc"): |
| 381 | + match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig) |
| 382 | + dtype = match.group(1) |
| 383 | + shape = match.group(2) |
| 384 | + ndim = shape.count(",") + 1 |
| 385 | + |
| 386 | + output.append("*" + dtype) |
| 387 | + # Currently the host side tensor descriptors get passed in as a |
| 388 | + # tensor desc, shape, and strides. We have no way to use these |
| 389 | + # shape and strides when processing tensor descriptors which is |
| 390 | + # why we provide our own decomposition above. Sadly this means |
| 391 | + # we have to pass the shape and strides twice. |
| 392 | + for _ in range(2 * ndim): |
| 393 | + output.append("i64") |
| 394 | + for _ in range(ndim): |
| 395 | + output.append("i32") |
| 396 | + for _ in range(ndim): |
| 397 | + output.append("i64") |
| 398 | + else: |
| 399 | + output.append(sig) |
| 400 | + |
| 401 | + return output |
| 402 | + |
| 403 | + def _flatten_signature(sig, output): |
374 | 404 | if isinstance(sig, tuple):
|
375 |
| - return ','.join(map(_serialize_signature, sig)) |
376 |
| - return sig |
| 405 | + for x in sig: |
| 406 | + _flatten_signature(x, output) |
| 407 | + else: |
| 408 | + output.append(sig) |
377 | 409 |
|
378 | 410 | def _extracted_type(ty):
|
379 | 411 | if isinstance(ty, tuple):
|
@@ -408,11 +440,16 @@ def format_of(ty):
|
408 | 440 | "uint64_t": "K",
|
409 | 441 | }[ty_to_cpp(ty)]
|
410 | 442 |
|
| 443 | + expand_signature = _expand_signature(signature.values()) |
| 444 | + signature = {i: s for i, s in enumerate(expand_signature)} |
| 445 | + |
411 | 446 | args_format = ''.join([format_of(ty) for ty in signature.values()])
|
412 | 447 | format = _BASE_ARGS_FORMAT + args_format
|
413 |
| - signature = ','.join(map(_serialize_signature, signature.values())) |
414 |
| - signature = list(filter(bool, signature.split(','))) |
415 |
| - signature = {i: s for i, s in enumerate(signature)} |
| 448 | + |
| 449 | + flat_signature = [] |
| 450 | + for sig in signature.values(): |
| 451 | + _flatten_signature(sig, flat_signature) |
| 452 | + signature = {i: s for i, s in enumerate(flat_signature)} |
416 | 453 | args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
|
417 | 454 | # Record the end of regular arguments;
|
418 | 455 | # subsequent arguments are architecture-specific descriptors.
|
@@ -632,9 +669,10 @@ def format_of(ty):
|
632 | 669 | PyObject* py_kernel;
|
633 | 670 |
|
634 | 671 | {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
|
635 |
| - if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &py_obj_stream, &py_kernel, |
636 |
| - &kernel_metadata, &launch_metadata, |
637 |
| - &launch_enter_hook, &launch_exit_hook {args_list})) {{ |
| 672 | + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, |
| 673 | + &py_obj_stream, &py_kernel, |
| 674 | + &kernel_metadata, &launch_metadata, |
| 675 | + &launch_enter_hook, &launch_exit_hook{args_list})) {{ |
638 | 676 | return NULL;
|
639 | 677 | }}
|
640 | 678 |
|
@@ -703,6 +741,32 @@ def format_of(ty):
|
703 | 741 | return src
|
704 | 742 |
|
705 | 743 |
|
def wrap_handle_tensor_descriptor(launcher):
    """
    Replace all tensor descriptors with the base ptr, shape, and strides
    """

    def inner(args):
        n_meta = len(_BASE_ARGS_FORMAT)
        expanded = []
        for value in args[n_meta:]:
            if not isinstance(value, TensorDescriptor):
                expanded.append(value)
                continue
            # Currently the host side tensor descriptors get decomposed in
            # the frontend to tensor desc, shape, and strides. We have no
            # way to use these shape and strides when processing tensor
            # descriptors which is why we provide our own decomposition
            # above. Sadly this means we have to pass the shape and strides
            # twice.
            expanded.append(value.base)
            expanded.extend(value.shape)
            expanded.extend(value.strides)
            expanded.extend(value.shape)
            expanded.extend(value.strides)

        return launcher(args[:n_meta] + tuple(expanded))

    return inner
| 768 | + |
| 769 | + |
706 | 770 | def serialize_args(args, constants, signature):
|
707 | 771 | import torch
|
708 | 772 | import numbers
|
@@ -767,17 +831,23 @@ class XPULauncher(object):
|
767 | 831 | def __init__(self, src, metadata):
|
768 | 832 | constants = src.constants if hasattr(src, "constants") else dict()
|
769 | 833 | arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
|
770 |
| - self.constants = {arg_idx(idx): value for idx, value in constants.items()} |
771 |
| - self.signature = {idx: value for idx, value in src.signature.items()} |
772 |
| - src = make_launcher(self.constants, self.signature) |
| 834 | + constants = {arg_idx(idx): value for idx, value in constants.items()} |
| 835 | + signature = {idx: value for idx, value in src.signature.items()} |
| 836 | + src = make_launcher(constants, signature) |
773 | 837 | self.mod = compile_module_from_src(src=src, name="__triton_launcher")
|
| 838 | + has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values()) |
| 839 | + |
| 840 | + self.launch = wrap_handle_tensor_descriptor(self.mod.launch) if has_tensor_desc_arg else self.mod.launch |
| 841 | + |
774 | 842 | # Serialize KernelArguments for SPIR-V Runner
|
775 | 843 | self.serialize_kernel_args = knobs.intel.dump_spirv_kernel_args
|
| 844 | + self.constants = constants |
| 845 | + self.signature = signature |
776 | 846 |
|
777 |
    def __call__(self, *args):
        """Launch the compiled kernel with *args*, optionally dumping them first."""
        if self.serialize_kernel_args:
            # NOTE(review): args here are the raw (unexpanded) caller arguments,
            # while self.signature may contain expanded tensordesc entries —
            # confirm serialize_args tolerates that mismatch.
            serialize_args(args, self.constants, self.signature)
        # self.launch is either mod.launch or its tensor-descriptor wrapper,
        # chosen in __init__.
        self.launch(args)
781 | 851 |
|
782 | 852 |
|
783 | 853 | class XPUDriver(DriverBase):
|
|
0 commit comments