Skip to content

Commit ca7e8ba

Browse files
authored
Generic SPIR-V Runner (#2258)
This PR implements serialization/de-serialization of kernel arguments and meta data from Triton backend and enables fully functional generic SPIRVRunner.
1 parent 8d5374a commit ca7e8ba

File tree

8 files changed

+382
-139
lines changed

8 files changed

+382
-139
lines changed

third_party/intel/backend/driver.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -435,19 +435,78 @@ def format_of(ty):
435435
return src
436436

437437

438+
def serialize_kernel_metadata(arg, args_dict):
    """Copy the launch-relevant kernel metadata fields into ``args_dict``.

    Args:
        arg: Metadata object exposing ``num_warps``, ``threads_per_warp``,
            ``shared`` and ``name`` attributes.
        args_dict: Dictionary updated in place with the keys ``num_warps``,
            ``threads_per_warp``, ``shared_memory``, ``kernel_name`` and
            ``spv_name``.
    """
    # (key written to the dictionary, attribute read from the metadata object)
    field_map = (
        ('num_warps', 'num_warps'),
        ('threads_per_warp', 'threads_per_warp'),
        ('shared_memory', 'shared'),
        ('kernel_name', 'name'),
    )
    for out_key, attr_name in field_map:
        args_dict[out_key] = getattr(arg, attr_name)
    # The SPIR-V binary file name is derived from the kernel name.
    args_dict['spv_name'] = f"{arg.name}.spv"
444+
445+
446+
def serialize_args(args, constants, signature):
    """Dump one kernel launch's arguments to disk for the SPIR-V Runner.

    The output directory is read from the ``TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS``
    environment variable and created if missing. Every ``torch.Tensor``
    argument is copied to the CPU and saved as ``tensor_<n>.pt``. The grid
    dimensions, tensor descriptions, non-constant scalar arguments and kernel
    metadata are written to ``args_data.json`` in the same directory.

    Args:
        args: Positional launch arguments. ``args[0:3]`` are stored as
            gridX/gridY/gridZ, ``args[3]`` is skipped (presumably the
            stream/queue handle -- TODO confirm against the launcher), and the
            remaining entries are inspected one by one.
        constants: Container of kernel-argument indices; a scalar whose index
            is in it is left out of the argument list.
        signature: Mapping from kernel-argument index to the type string
            stored under ``"ctype"``.

    Raises:
        ValueError: If ``TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS`` is unset or empty.
    """
    import json
    import numbers

    import torch

    dir_path = os.getenv('TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS')
    if not dir_path:
        # Without this guard os.path would fail with an opaque TypeError on None.
        raise ValueError("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS must be set to the dump directory path")
    if not os.path.isdir(dir_path):
        # exist_ok avoids a crash if another process creates the directory
        # between the check above and this call.
        os.makedirs(dir_path, exist_ok=True)
        print(f"Path to directory consisting of SPIR-V Runner data: {dir_path}")

    args_dict = {"gridX": args[0], "gridY": args[1], "gridZ": args[2], "argument_list": []}
    argument_list = args_dict['argument_list']

    tensor_idx = 0  # numbers the tensor_<n>.pt files
    scalar_idx = 0  # numbers the scalarArg_<n> entries
    karg_idx = 0  # position of the current kernel argument in `signature`

    for arg in args[4:]:
        if type(arg).__name__ == "KernelMetadata":
            serialize_kernel_metadata(arg, args_dict)
        elif isinstance(arg, torch.Tensor):
            tensor_name = f"tensor_{tensor_idx}"
            with open(os.path.join(dir_path, f"{tensor_name}.pt"), 'wb') as tensor_file:
                torch.save(arg.cpu(), tensor_file)
            argument_list.append({
                "name": tensor_name,
                "type": "tensor",
                "dtype": str(arg.dtype),
                "ctype": signature[karg_idx],
            })
            karg_idx += 1
            tensor_idx += 1
        elif isinstance(arg, numbers.Number):
            if karg_idx not in constants:
                argument_list.append({
                    "name": f"scalarArg_{scalar_idx}",
                    "type": "scalar",
                    "value": arg,
                    "ctype": signature[karg_idx],
                })
            # Constants still occupy a slot in the signature, so both counters
            # advance even when the scalar itself is not recorded.
            # NOTE(review): nesting reconstructed from a flattened diff -- confirm.
            karg_idx += 1
            scalar_idx += 1

    # Dump argument info as a JSON file
    json_path = os.path.join(dir_path, 'args_data.json')
    with open(json_path, 'w') as json_file:
        json.dump(args_dict, json_file, indent=4)
491+
492+
438493
class XPULauncher(object):
    """Callable wrapper around the compiled ``__triton_launcher`` module.

    The constructor normalizes the constants and signature of ``src`` so they
    are keyed by argument index, keeps them on the instance, and stores the
    compiled module's ``launch`` entry point.
    """

    def __init__(self, src, metadata):
        """Build and compile the launcher for ``src``.

        Args:
            src: Object providing ``signature`` and, optionally, ``fn`` and
                ``constants``.
            metadata: Not used by the code in this constructor.
        """

        def _to_index(key):
            # Argument names are translated to their position; anything else
            # is taken to be an index already.
            return src.fn.arg_names.index(key) if isinstance(key, str) else key

        const_expr_ids = src.fn.constexprs if hasattr(src, "fn") else tuple()
        raw_constants = src.constants if hasattr(src, "constants") else dict()

        self.constants = {_to_index(name): val for name, val in raw_constants.items()}
        self.signature = {_to_index(name): val for name, val in src.signature.items()}

        launcher_src = make_launcher(self.constants, self.signature, {"ids_of_const_exprs": const_expr_ids})
        launcher_mod = compile_module_from_src(launcher_src, "__triton_launcher")
        self.launch = launcher_mod.launch

    def __call__(self, *args, **kwargs):
        """Launch the kernel, first dumping its arguments when requested."""
        # Serialize KernelArguments for SPIR-V Runner
        if os.getenv('TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS', None):
            serialize_args(args, self.constants, self.signature)
        self.launch(*args, **kwargs)
452511

453512

utils/SPIRVRunner/CMakeLists.txt

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
11
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
22
project(reproducer)
3-
43
set(CMAKE_CXX_COMPILER icpx)
54
set(BUILD_SHARED_LIBS OFF)
65

76
list(APPEND CMAKE_PREFIX_PATH "/opt/intel/oneapi/tbb/latest/lib/cmake/tbb/")
87

98
find_package(Torch REQUIRED)
109

10+
include(ExternalProject)
11+
ExternalProject_Add(
12+
json
13+
GIT_REPOSITORY https://github.com/nlohmann/json.git
14+
GIT_TAG v3.11.2
15+
PREFIX ${CMAKE_BINARY_DIR}/nlohmann_json
16+
CONFIGURE_COMMAND ""
17+
BUILD_COMMAND ""
18+
INSTALL_COMMAND ""
19+
)
20+
set(JSON_INCLUDE_DIR ${CMAKE_BINARY_DIR}/nlohmann_json/src/json/include/)
21+
1122
# Add preview-breaking-changes for ABI compatibility with SYCL library linked by PyTorch: https://github.com/pytorch/pytorch/commit/92bebb46fa9fd60523d8aeb7b5f1a3f488c4cd93
1223
set(COMPILE_FLAGS "-fsycl -Wall -fpreview-breaking-changes")
1324
set(LINK_FLAGS "-fsycl -lze_loader")
@@ -16,9 +27,10 @@ set(SYCL_FUNCTIONS_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/in
1627

1728
set(TARGET_NAME SPIRVRunner)
1829
add_executable(${TARGET_NAME} ${TARGET_NAME}.cpp)
19-
target_include_directories(${TARGET_NAME} PRIVATE "/opt/intel/oneapi/compiler/latest/include" ${SYCL_FUNCTIONS_INCLUDE_DIR})
30+
target_include_directories(${TARGET_NAME} PRIVATE "/opt/intel/oneapi/compiler/latest/include" ${SYCL_FUNCTIONS_INCLUDE_DIR} ${JSON_INCLUDE_DIR})
2031
set_target_properties(${TARGET_NAME} PROPERTIES COMPILE_FLAGS "${COMPILE_FLAGS}")
2132
set_target_properties(${TARGET_NAME} PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
33+
add_dependencies(${TARGET_NAME} json)
2234

2335
target_link_libraries(${TARGET_NAME} "${TORCH_LIBRARIES}")
2436
set_property(TARGET ${TARGET_NAME} PROPERTY CXX_STANDARD 17)

utils/SPIRVRunner/README.md

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,44 @@ CMAKE_PREFIX_PATH=/abs/path/to/TorchConfig.cmake/FromAbove/ cmake -DCMAKE_BUILD_
1717
make -j
1818
```
1919

20-
## Configuring
20+
## Configuration
2121

22-
`SPIRVRunner` is configured to run the `add_kernel.spv` SPIRV binary with inputs `x.py` and `y.py`. `add_kernel.spv` was generated from the `01-vector-add.py` tutorial.
22+
### Generate Data
2323

24-
Kernels of different shapes require modifying parameters manually in the `SPIRVRunner`. Two places require modification:
24+
To use this utility, the Triton application must be run with the following environment variable set.
25+
Provide the path to the directory where the serialized JSON, tensors, and SPIR-V binary will be stored. It is recommended to clear the Triton cache first.
26+
27+
```
28+
export TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS=< Absolute path to SPV Dumps >
29+
```
30+
31+
The following input data is generated:
32+
33+
1. args_data.json - (Kernel Arguments / Grid Configuration)
34+
2. tensors (Tensors used by the kernel (.pt))
35+
3. SPIR-V binary (.spv)
2536

26-
1. `launchKernel`: Add input Tensors to the function signature, add arguments as variables within the function. Arguments can be pulled from the `args` variable to `XPULauncher.__call__` method in `driver.py`. Arguments should be passed to the `sycl_kernel_launch` function. Note that we currently rely on `sycl::memcpy` to move the PyTorch Tensor to XPU. In later versions of PyTorch we should be able to delegate this responsibility to `PyTorch`, and pass the raw XPU `data_ptr()` from `PyTorch` to the kernel.
27-
2. `sycl_kernel_launch`: Place all `arg*` parameters into the `params` array and add an appropriate call to `set_scalar_arg` for each param, which tells `SYCL` what the arguments are for the kernel we are going to launch.
2837

2938
## Running
3039

31-
Once the `SPIRVRunner` has been appropriately configured for the kernel and inputs, run the binary with no arguments:
40+
Help:
41+
`./build/SPIRVRunner` < Output Tensor Name >
42+
43+
Note: `Output Tensor Name` is the tensor that should be copied back to the CPU and written to disk. The name must match a tensor name (`tensor_` followed by its number) as listed in the JSON file; refer to the `args_data.json` file.
44+
45+
### Demo (01-vector-add.py)
46+
47+
`SPIRVRunner` is configured to run the `add_kernel.spv` SPIRV binary with inputs `tensor_0.pt` and `tensor_1.pt` and output `tensor_2.pt`. `add_kernel.spv` was generated from the `01-vector-add.py` tutorial.
3248

33-
`./build/SPIRVRunner`
49+
SPIRVRunner Usage:
50+
`./build/SPIRVRunner tensor_2`
3451

3552
Expected output follows:
3653

3754
```
3855
Running on device: Intel(R) Data Center GPU Max 1100
39-
Tensor a: [98432], Float (393728 bytes)
40-
Tensor b: [98432], Float (393728 bytes)
4156
Read 3772 byte kernel.
57+
create kernel:add_kernel
4258
Loaded kernel with 0 registers and 0 register spills.
4359
Tensor output: [98432], Float (393728 bytes)
4460
Kernel return output: 1.37129

0 commit comments

Comments
 (0)