Commit 02c26fd

Workflow to test SPIRVRunner (#3400)
1 parent 95bba24 commit 02c26fd

File tree: 4 files changed (+191, −5 lines)
Lines changed: 58 additions & 0 deletions

```yaml
name: Test SPIRVRunner

on:
  workflow_dispatch:

  pull_request:
    branches:
      - main
  push:
    branches:
      - main

permissions: read-all

env:
  PYTHON_VERSION: '3.9'

jobs:
  tests:
    name: Tests
    runs-on:
      - rolling
      - runner-0.0.22
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Install Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Setup PyTorch
        uses: ./.github/actions/setup-pytorch

      - name: Setup Triton
        uses: ./.github/actions/setup-triton

      - name: Build SPIRVRunner
        run: |
          source /opt/intel/oneapi/setvars.sh
          set -x
          export LLVM_DIR="$HOME/.triton/llvm/llvm-ubuntu-x64"
          export CMAKE_PREFIX_PATH="$(python scripts/torch_cmake.py)"
          cd utils/SPIRVRunner
          mkdir build
          cd build
          cmake -DCMAKE_BUILD_TYPE=RelWithDebInfo ..
          make -j

      - name: Test SPIRVRunner
        run: |
          source /opt/intel/oneapi/setvars.sh
          set -x
          export SPIRV_RUNNER_PATH="$GITHUB_WORKSPACE/utils/SPIRVRunner/build/SPIRVRunner"
          export SPIRV_RUNNER_TESTS="$GITHUB_WORKSPACE/utils/SPIRVRunner/tests"
          cd utils/SPIRVRunner
          pytest tests/test_spirv_runner.py
```
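The test step's contract with `tests/test_spirv_runner.py` (which is not part of this diff) is carried entirely by the two exported variables. A plausible sketch of how a test might consume them; the test names are assumptions, not taken from the commit:

```python
# Hypothetical consumer of the variables exported by the workflow above;
# tests/test_spirv_runner.py itself is not shown in this diff.
import os
import pathlib


def test_runner_binary_was_built():
    # SPIRV_RUNNER_PATH points at the binary produced by the Build step.
    runner = pathlib.Path(os.environ['SPIRV_RUNNER_PATH'])
    assert runner.is_file(), 'Build SPIRVRunner step must run first'


def test_tests_directory_exists():
    # SPIRV_RUNNER_TESTS points at the checked-in test inputs.
    tests = pathlib.Path(os.environ['SPIRV_RUNNER_TESTS'])
    assert tests.is_dir()
```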

scripts/torch_cmake.py

Lines changed: 22 additions & 0 deletions
```python
"""Prints cmake directory for PyTorch."""

import importlib.metadata
import pathlib


def get_torch_cmake_path() -> pathlib.Path:
    """Returns directory that contains TorchConfig.cmake.

    Raises:
        importlib.metadata.PackageNotFoundError: if torch not installed.
        AssertionError: if TorchConfig.cmake not found.
    """
    files = importlib.metadata.files('torch') or []
    for f in files:
        if f.name == 'TorchConfig.cmake':
            return pathlib.Path(f.locate()).parent.resolve()
    raise AssertionError('TorchConfig.cmake not found')


if __name__ == '__main__':
    print(get_torch_cmake_path())
```
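The workflow consumes this script via `CMAKE_PREFIX_PATH="$(python scripts/torch_cmake.py)"`. A quick local sanity check might look like this (hypothetical, not part of the commit):

```python
# Hypothetical sanity check for scripts/torch_cmake.py; assumes the repository
# root is the current working directory.
import pathlib
import sys

sys.path.insert(0, 'scripts')
from torch_cmake import get_torch_cmake_path

path = get_torch_cmake_path()
# The returned directory must actually contain TorchConfig.cmake.
assert (path / 'TorchConfig.cmake').is_file()
print(f'CMAKE_PREFIX_PATH would be: {path}')
```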

utils/SPIRVRunner/README.md

Lines changed: 18 additions & 5 deletions
````diff
@@ -4,13 +4,26 @@ A utility program for running Triton-generated SPIR-V kernels with identical inp
 
 ## Building
 
-`SPIRVRunner` depends on Torch. If you build Triton with virtualenvs, you can easily find your torch library path by running
+`SPIRVRunner` depends on Torch.
+
+If you build Triton with venv, you can easily find your torch library path by running the following command in the top-level Triton directory:
+
 ```
 find .venv -name TorchConfig.cmake
 ```
-in the top level Triton directory.
 
-`SPIRVRunner` depends on LLVM support libarary for argument parsing in order to use this run following in the top level Triton directory.
+Alternatively, you can find `TorchConfig.cmake` with the following Python script:
+
+```python
+import importlib.metadata
+
+for f in importlib.metadata.files('torch'):
+    if f.name == 'TorchConfig.cmake':
+        print(f.locate().resolve())
+```
+
+`SPIRVRunner` depends on the LLVM support library for argument parsing. To use it, run the following in the top-level Triton directory:
+
 ```
 scripts/compile-triton.sh --llvm
 ```
@@ -20,7 +33,7 @@ SPIR-V Runner build steps:
 ```
 mkdir build
 cd build
-CMAKE_PREFIX_PATH=/abs/path/to/TorchConfig.cmake/FromAbove/ LLVM_DIR=/abs/path/to/packages/llvm cmake -DCMAKE_BUILD_TYPE=RelWithDebInfo ..
+CMAKE_PREFIX_PATH=/abs/path/to/TorchConfig.cmake/directory LLVM_DIR=/abs/path/to/packages/llvm cmake -DCMAKE_BUILD_TYPE=RelWithDebInfo ..
 make -j
 ```
 
@@ -29,7 +42,7 @@ make -j
 ### Generate Data
 
 In order to use this utility, the Triton application must be run with the following environment variable enabled.
-Provide the path to the directory where the serialized JSON, tensors and SPRI-V binary stored. It is recommended to clear triton cache.
+Provide the path to the directory where the serialized JSON, tensors, and SPIR-V binary are stored. It is recommended to clear the Triton cache.
 
 ```
 export TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS=< Absolute path to SPV Dumps >
````
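The README recommends clearing the Triton cache before dumping; a minimal sketch, assuming Triton's default `~/.triton/cache` location:

```python
# Minimal sketch: clear the Triton kernel cache so stale binaries are not replayed.
# Assumes the default cache directory of ~/.triton/cache.
import pathlib
import shutil

cache = pathlib.Path.home() / '.triton' / 'cache'
shutil.rmtree(cache, ignore_errors=True)
print(f'Cleared {cache}')
```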
Lines changed: 93 additions & 0 deletions
```python
"""
Vector Addition
===============

In this tutorial, you will write a simple vector addition using Triton.

In doing so, you will learn about:

* The basic programming model of Triton.

* The `triton.jit` decorator, which is used to define Triton kernels.

* The best practices for validating and benchmarking your custom ops against native reference implementations.

"""

# %%
# Compute Kernel
# --------------

import torch

import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()


@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)


# %%
# Let's also declare a helper function to (1) allocate the `z` tensor
# and (2) enqueue the above kernel with appropriate grid/block sizes:


def add(x: torch.Tensor, y: torch.Tensor):
    # We need to preallocate the output.
    output = torch.empty_like(x)
    assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
    n_elements = output.numel()
    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid where the size is the number of blocks:
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    # NOTE:
    #  - Each torch.tensor object is implicitly converted into a pointer to its first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
    #  - Don't forget to pass meta-parameters as keyword arguments.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
    # running asynchronously at this point.
    return output


# %%
# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:

torch.manual_seed(0)
size = 98432
x = torch.rand(size, device=DEVICE)
y = torch.rand(size, device=DEVICE)
output_torch = x + y
output_triton = add(x, y)
print(output_torch.cpu())
print(output_triton.cpu())
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch.cpu() - output_triton.cpu()))}')
```
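Taken together with the README changes, this tutorial is presumably the workload whose dumped kernel the runner replays. A sketch of the dump-and-replay loop; the dump path, script name, and the runner's command line are all assumptions, not taken from this diff:

```python
# Hypothetical dump-and-replay flow around the vector-add script above.
import os
import subprocess

env = dict(os.environ)
env['TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS'] = '/tmp/spirv_dump'  # assumed dump directory

# 1. Run the Triton workload so its SPIR-V, tensors, and JSON metadata are dumped.
subprocess.run(['python', 'vector-add.py'], env=env, check=True)  # assumed script name

# 2. Replay the captured kernel through the runner binary built by the workflow.
runner = os.environ.get('SPIRV_RUNNER_PATH', 'utils/SPIRVRunner/build/SPIRVRunner')
subprocess.run([runner, '/tmp/spirv_dump'], check=True)  # assumed CLI: runner <dump dir>
```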
