Skip to content

Commit 58e3f28

Browse files
committed
add xegpu matmul example
1 parent 72809e9 commit 58e3f28

File tree

6 files changed

+1333
-0
lines changed

6 files changed

+1333
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# XeGPU matrix multiplication benchmark
2+
3+
## Installation
4+
5+
### 1. GPU Drivers and Level Zero
6+
7+
Install Intel GPU drivers and Level Zero runtime on your system.
8+
9+
### 2. Compile LLVM with Intel GPU support
10+
11+
To use Lighthouse with Intel GPUs, LLVM must be built with LevelZero runtime.
12+
13+
Set up a Python environment and install Python packages:
14+
15+
```bash
16+
pip install pybind11 nanobind PyYAML numpy
17+
```
18+
19+
Set `LLVM_INSTALL_DIR` and use the below script to checkout and compile LLVM locally.
20+
21+
```bash
22+
export LLVM_INSTALL_DIR=<...>
23+
LLVM_VERSION=83765f435d1c
24+
git checkout https://github.com/llvm/llvm-project.git -b $LLVM_VERSION
25+
26+
cd llvm-project
27+
mkdir -p build
28+
cd build
29+
30+
cmake ../llvm -G Ninja \
31+
-DCMAKE_BUILD_TYPE=Release \
32+
-DLLVM_ENABLE_PROJECTS=mlir \
33+
-DLLVM_BUILD_EXAMPLES=OFF \
34+
-DLLVM_TARGETS_TO_BUILD="host" \
35+
-DLLVM_ENABLE_ASSERTIONS=ON \
36+
-DLLVM_ENABLE_RTTI=ON \
37+
-DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD="SPIRV" \
38+
-DLLVM_INSTALL_GTEST=ON \
39+
-DMLIR_ENABLE_LEVELZERO_RUNNER=1 \
40+
-DMLIR_ENABLE_BINDINGS_PYTHON=1 \
41+
-DPython3_EXECUTABLE=$(which python3) \
42+
-DLLVM_INSTALL_UTILS=ON \
43+
-DCMAKE_INSTALL_PREFIX=${INSTALL_DIR}
44+
cmake --build .
45+
cmake --install .
46+
```
47+
48+
If cmake cannot find LevelZero, set environment variable `LEVEL_ZERO_DIR=<path-to-level-zero-install-root>`.
49+
50+
### Install Lighthouse
51+
52+
Install Lighthouse as instructed in the main [README](../../../../README.md).
53+
54+
Override the default LLVM package by setting `PYTHONPATH` to the local LLVM Python bindings:
55+
56+
```bash
57+
export PYTHONPATH=${LLVM_INSTALL_DIR}/python_packages/mlir_core
58+
```
59+
60+
## Usage
61+
62+
Run the default 4k (float16, float16) -> float32 matrix multiplication benchmark with correctness test:
63+
64+
```bash
65+
python matmul.py --check-result
66+
```
67+
68+
Set different M, N, K problem size
69+
70+
```bash
71+
python matmul.py --sizes 1024 2048 4096 ...
72+
```
73+
74+
Run with ReLU post-op:
75+
76+
```bash
77+
python matmul.py --relu ...
78+
```
79+
80+
See all command line arguments:
81+
82+
```bash
83+
python matmul.py --help
84+
```
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
import numpy as np
2+
import ctypes
3+
import os
4+
from typing import Optional
5+
6+
from mlir.dialects.transform import interpreter as transform_interpreter
7+
from mlir.dialects import func, arith, scf, memref
8+
from mlir.execution_engine import ExecutionEngine
9+
from mlir import ir
10+
from mlir.runtime.np_to_memref import get_ranked_memref_descriptor
11+
12+
from lighthouse.utils import get_packed_arg
13+
from mlir_utils import get_mlir_library_path
14+
15+
16+
def get_engine(payload_module, opt_level=3) -> ExecutionEngine:
17+
context = ir.Context()
18+
location = ir.Location.unknown(context)
19+
lib_dir = get_mlir_library_path()
20+
libs = [
21+
"libmlir_levelzero_runtime.so",
22+
"libmlir_runner_utils.so",
23+
"libmlir_c_runner_utils.so",
24+
]
25+
libs = [os.path.join(lib_dir, lib) for lib in libs]
26+
with context, location:
27+
execution_engine = ExecutionEngine(
28+
payload_module, opt_level=opt_level, shared_libs=libs
29+
)
30+
execution_engine.initialize()
31+
return execution_engine
32+
33+
34+
def apply_transform_schedule(
35+
payload_module,
36+
schedule_module,
37+
context,
38+
location,
39+
dump_kernel: Optional[str] = None,
40+
dump_schedule: bool = False,
41+
):
42+
if not dump_kernel or dump_kernel != "initial":
43+
with context, location:
44+
# invoke transform interpreter directly
45+
transform_interpreter.apply_named_sequence(
46+
payload_root=payload_module,
47+
transform_root=schedule_module.body.operations[0],
48+
transform_module=schedule_module,
49+
)
50+
if dump_kernel:
51+
print(payload_module)
52+
if dump_schedule:
53+
print(schedule_module)
54+
55+
56+
def lower_payload(
57+
workload,
58+
dump_kernel: Optional[str] = None,
59+
dump_schedule: bool = False,
60+
schedule_parameters: Optional[dict] = None,
61+
) -> ir.Module:
62+
payload_module = workload.payload_module()
63+
schedule_module = workload.schedule_module(
64+
dump_kernel=dump_kernel, parameters=schedule_parameters
65+
)
66+
apply_transform_schedule(
67+
payload_module,
68+
schedule_module,
69+
workload.context,
70+
workload.location,
71+
dump_kernel=dump_kernel,
72+
dump_schedule=dump_schedule,
73+
)
74+
return payload_module
75+
76+
77+
def execute(
78+
workload,
79+
check_correctness: bool = True,
80+
schedule_parameters: Optional[dict] = None,
81+
verbose: int = 0,
82+
):
83+
# lower payload with schedule
84+
payload_module = lower_payload(workload, schedule_parameters=schedule_parameters)
85+
# get execution engine
86+
engine = get_engine(payload_module, requirements=workload.requirements())
87+
88+
with workload.allocate(execution_engine=engine):
89+
# prepare function arguments
90+
inputs = workload.get_input_arrays(execution_engine=engine)
91+
pointers = [ctypes.pointer(ctypes.pointer(m)) for m in inputs]
92+
packed_args = get_packed_arg(pointers)
93+
94+
# handle to payload function
95+
payload_func = engine.lookup(workload.payload_function_name)
96+
97+
# call
98+
payload_func(packed_args)
99+
100+
if check_correctness:
101+
workload.check_correctness(execution_engine=engine, verbose=verbose)
102+
103+
104+
def benchmark(
105+
workload,
106+
nruns: int = 100,
107+
nwarmup: int = 10,
108+
schedule_parameters: Optional[dict] = None,
109+
check_correctness: bool = True,
110+
verbose: int = 0,
111+
) -> np.ndarray:
112+
# get original payload module
113+
payload_module = workload.payload_module()
114+
115+
# find payload function
116+
payload_func = None
117+
for op in payload_module.operation.regions[0].blocks[0]:
118+
if (
119+
isinstance(op, func.FuncOp)
120+
and str(op.name).strip('"') == workload.payload_function_name
121+
):
122+
payload_func = op
123+
break
124+
assert payload_func is not None, "Could not find payload function"
125+
payload_arguments = payload_func.type.inputs
126+
127+
# emit benchmark function that calls payload and times it
128+
with workload.context, workload.location:
129+
with ir.InsertionPoint(payload_module.body):
130+
# define rtclock function
131+
f64_t = ir.F64Type.get()
132+
f = func.FuncOp("rtclock", ((), (f64_t,)), visibility="private")
133+
# emit benchmark function
134+
time_memref_t = ir.MemRefType.get((nruns,), f64_t)
135+
args = payload_arguments + [time_memref_t]
136+
f = func.FuncOp("benchmark", (tuple(args), ()))
137+
f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
138+
with ir.InsertionPoint(f.add_entry_block()):
139+
index_t = ir.IndexType.get()
140+
zero = arith.ConstantOp(index_t, 0)
141+
one = arith.ConstantOp(index_t, 1)
142+
nwarmup_cst = arith.ConstantOp(index_t, nwarmup)
143+
for_op = scf.ForOp(zero, nwarmup_cst, one)
144+
with ir.InsertionPoint(for_op.body):
145+
func.CallOp(payload_func, list(f.arguments[: len(payload_arguments)]))
146+
scf.YieldOp(())
147+
nruns_cst = arith.ConstantOp(index_t, nruns)
148+
for_op = scf.ForOp(zero, nruns_cst, one)
149+
i = for_op.induction_variable
150+
with ir.InsertionPoint(for_op.body):
151+
tic = func.CallOp((f64_t,), "rtclock", ()).result
152+
func.CallOp(payload_func, list(f.arguments[: len(payload_arguments)]))
153+
toc = func.CallOp((f64_t,), "rtclock", ()).result
154+
time = arith.SubFOp(toc, tic)
155+
memref.StoreOp(time, f.arguments[-1], [i])
156+
scf.YieldOp(())
157+
func.ReturnOp(())
158+
159+
# lower
160+
apply_transform_schedule(
161+
payload_module,
162+
workload.schedule_module(parameters=schedule_parameters),
163+
workload.context,
164+
workload.location,
165+
)
166+
# get execution engine, rtclock requires mlir_c_runner
167+
engine = get_engine(payload_module)
168+
169+
with workload.allocate(execution_engine=engine):
170+
inputs = workload.get_input_arrays(execution_engine=engine)
171+
pointers = [ctypes.pointer(ctypes.pointer(m)) for m in inputs]
172+
if check_correctness:
173+
# call payload once to verify correctness
174+
# prepare function arguments
175+
packed_args = get_packed_arg(pointers)
176+
177+
payload_func = engine.lookup(workload.payload_function_name)
178+
payload_func(packed_args)
179+
success = workload.check_correctness(
180+
execution_engine=engine, verbose=verbose
181+
)
182+
if not success:
183+
raise ValueError("Benchmark verification failed.")
184+
185+
# allocate buffer for timings and prepare arguments
186+
time_array = np.zeros((nruns,), dtype=np.float64)
187+
time_memref = get_ranked_memref_descriptor(time_array)
188+
time_pointer = ctypes.pointer(ctypes.pointer(time_memref))
189+
packed_args_with_time = get_packed_arg(pointers + [time_pointer])
190+
191+
# call benchmark function
192+
benchmark_func = engine.lookup("benchmark")
193+
benchmark_func(packed_args_with_time)
194+
195+
return time_array

0 commit comments

Comments
 (0)