diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 3396fb1..b889a85 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -28,4 +28,4 @@ jobs: - name: Run lit-enabled examples as tests run: | export FILECHECK=FileCheck-18 # Ubuntu's llvm-dev appends a version number. - uv run lit python/examples # Makes sure to substitute FileCheck for $FILECHECK + uv run lit python/examples --verbose # Makes sure to substitute FileCheck for $FILECHECK diff --git a/python/examples/xegpu_matmul/README.md b/python/examples/xegpu_matmul/README.md new file mode 100644 index 0000000..0070c62 --- /dev/null +++ b/python/examples/xegpu_matmul/README.md @@ -0,0 +1,84 @@ +# XeGPU matrix multiplication benchmark + +## Installation + +### 1. GPU Drivers and Level Zero + +Install Intel GPU drivers and Level Zero runtime on your system. + +### 2. Compile LLVM with Intel GPU support + +To use Lighthouse with Intel GPUs, LLVM must be built with LevelZero runtime. + +Set up a Python environment and install Python packages: + +```bash +pip install pybind11 nanobind PyYAML numpy +``` + +Set `LLVM_INSTALL_DIR` and use the below script to checkout and compile LLVM locally. + +```bash +export LLVM_INSTALL_DIR=<...> +LLVM_VERSION=83765f435d1c +git checkout https://github.com/llvm/llvm-project.git -b $LLVM_VERSION + +cd llvm-project +mkdir -p build +cd build + +cmake ../llvm -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_BUILD_EXAMPLES=OFF \ + -DLLVM_TARGETS_TO_BUILD="host" \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_ENABLE_RTTI=ON \ + -DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD="SPIRV" \ + -DLLVM_INSTALL_GTEST=ON \ + -DMLIR_ENABLE_LEVELZERO_RUNNER=1 \ + -DMLIR_ENABLE_BINDINGS_PYTHON=1 \ + -DPython3_EXECUTABLE=$(which python3) \ + -DLLVM_INSTALL_UTILS=ON \ + -DCMAKE_INSTALL_PREFIX=${INSTALL_DIR} +cmake --build . +cmake --install . +``` + +If cmake cannot find LevelZero, set environment variable `LEVEL_ZERO_DIR=`. + +### Install Lighthouse + +Install Lighthouse as instructed in the main [README](../../../README.md). + +Override the default LLVM package by setting `PYTHONPATH` to the local LLVM Python bindings: + +```bash +export PYTHONPATH=${LLVM_INSTALL_DIR}/python_packages/mlir_core +``` + +## Usage + +Run the default 4k (float16, float16) -> float32 matrix multiplication benchmark with correctness test: + +```bash +python matmul.py --check-result +``` + +Set different M, N, K problem size + +```bash +python matmul.py --sizes 1024 2048 4096 ... +``` + +Run with ReLU post-op: + +```bash +python matmul.py --relu ... +``` + +See all command line arguments: + +```bash +python matmul.py --help +``` diff --git a/python/examples/xegpu_matmul/lit.local.cfg b/python/examples/xegpu_matmul/lit.local.cfg new file mode 100644 index 0000000..b310830 --- /dev/null +++ b/python/examples/xegpu_matmul/lit.local.cfg @@ -0,0 +1 @@ +config.excludes = ["mlir_utils.py", "payload.py", "runner.py", "schedule.py"] diff --git a/python/examples/xegpu_matmul/matmul.py b/python/examples/xegpu_matmul/matmul.py new file mode 100644 index 0000000..32b397f --- /dev/null +++ b/python/examples/xegpu_matmul/matmul.py @@ -0,0 +1,410 @@ +# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s +# CHECK: module attributes {gpu.container_module} { + +""" +XeGPU matrix multiplication benchmark. +""" + +import numpy as np +from mlir import ir +from mlir.runtime.np_to_memref import ( + get_ranked_memref_descriptor, + make_nd_memref_descriptor, + as_ctype, +) +from mlir.execution_engine import ExecutionEngine +from typing import Optional +import ctypes +from contextlib import contextmanager +from functools import cached_property +from lighthouse.utils import get_packed_arg, memref_to_ctype +from schedule import get_schedule_module +from payload import generate_matmul_payload + +from runner import lower_payload, benchmark +import argparse + + +def numpy_to_ctype(arr: np.ndarray) -> ctypes._Pointer: + """Convert numpy array to memref and ctypes **void pointer.""" + return memref_to_ctype(get_ranked_memref_descriptor(arr)) + + +class XeGPUMatMul: + """ + Matrix multiplication workload on XeGPU. + + Computes C = A * B for input matrices A (M x K) and B (K x N). + + Optionally adds a ReLU operation on the result C. + Optionally adds a bias term to C (not implemented yet). + """ + + payload_function_name = "payload" + + def __init__( + self, + M: int, + N: int, + K: int, + ab_type: str = "f16", + c_type: str = "f32", + has_bias: bool = False, + has_relu: bool = False, + ): + self.M = M + self.N = N + self.K = K + self.a_shape = (M, K) + self.b_shape = (K, N) + self.c_shape = (M, N) + assert ab_type == "f16", "Only f16 type is supported for A and B" + assert c_type == "f32", "Only f32 type is supported for C" + self.ab_type = ab_type + self.c_type = c_type + type_str_to_numpy = { + "f16": np.float16, + "f32": np.float32, + } + self.ab_dtype = type_str_to_numpy[ab_type] + self.c_dtype = type_str_to_numpy[c_type] + self.has_bias = has_bias + self.has_relu = has_relu + if has_bias: + raise NotImplementedError("Bias is not implemented yet") + # cache allocated memrefs + self.gpu_memrefs = {} + + def _allocate_array( + self, + name: str, + shape: tuple[int, ...], + dtype_str: str, + execution_engine: ExecutionEngine, + ) -> ctypes.Structure: + key = (name, dtype_str) + if key in self.gpu_memrefs: + return self.gpu_memrefs[key] + dtype = { + "f16": np.float16, + "f32": np.float32, + }[dtype_str] + alloc_func = execution_engine.lookup("gpu_alloc_" + dtype_str) + mref = make_nd_memref_descriptor(len(shape), as_ctype(dtype))() + ptr_mref = ctypes.pointer(ctypes.pointer(mref)) + ptr_dims = [ctypes.pointer(ctypes.c_int32(d)) for d in shape] + alloc_func(get_packed_arg([ptr_mref] + ptr_dims)) + self.gpu_memrefs[key] = mref + return mref + + def _deallocate_all(self, execution_engine: ExecutionEngine): + for (_, dtype_str), mref in self.gpu_memrefs.items(): + dealloc_func = execution_engine.lookup("gpu_dealloc_" + dtype_str) + ptr_mref = ctypes.pointer(ctypes.pointer(mref)) + dealloc_func(get_packed_arg([ptr_mref])) + self.gpu_memrefs = {} + + @contextmanager + def allocate_inputs(self, execution_engine: ExecutionEngine): + try: + inputs = self._get_input_arrays(execution_engine) + yield inputs + finally: + self._deallocate_all(execution_engine) + + @cached_property + def _initial_host_arrays(self) -> list[np.ndarray]: + """Generate initial values on host with numpy.""" + + # use integer values to avoid f16/f32 floating point discrepancies + def gen_random(shape, dtype): + # generate values in range [-3, 3] + a = np.round(6 * np.random.random_sample(shape)) - 3 + return a.astype(dtype) + + np.random.seed(2) + A = gen_random((self.M, self.K), self.ab_dtype) + B = gen_random((self.K, self.N), self.ab_dtype) + C = gen_random((self.M, self.N), self.c_dtype) + return A, B, C + + @cached_property + def _reference_solution(self) -> np.ndarray: + """Compute reference solution on host with numpy.""" + A, B, C = self._initial_host_arrays + # use float32 data type for efficiency + f32 = np.float32 + C_ref = A.astype(f32) @ B.astype(f32) + C.astype(f32) + if self.has_relu: + C_ref = np.maximum(C_ref, 0) + if self.has_bias: + raise NotImplementedError("Bias verification not implemented") + return C_ref + + def _get_input_arrays( + self, execution_engine: ExecutionEngine + ) -> list[ctypes.Structure]: + A_gpu = self._allocate_array("A", self.a_shape, self.ab_type, execution_engine) + B_gpu = self._allocate_array("B", self.b_shape, self.ab_type, execution_engine) + C_gpu = self._allocate_array("C", self.c_shape, self.c_type, execution_engine) + + A_host, B_host, C_host = self._initial_host_arrays + # copy initial values to device + copy_func_ab = execution_engine.lookup("gpu_copy_" + self.ab_type) + copy_func_c = execution_engine.lookup("gpu_copy_" + self.c_type) + copy_func_ab(get_packed_arg([numpy_to_ctype(A_host), memref_to_ctype(A_gpu)])) + copy_func_ab(get_packed_arg([numpy_to_ctype(B_host), memref_to_ctype(B_gpu)])) + copy_func_c(get_packed_arg([numpy_to_ctype(C_host), memref_to_ctype(C_gpu)])) + + # return memrefs for the payload function + return [A_gpu, B_gpu, C_gpu] + + def check_correctness( + self, execution_engine: ExecutionEngine, verbose: int = 0 + ) -> bool: + # copy result from device to host + C_gpu = self.gpu_memrefs[("C", self.c_type)] + C_host_copy = np.zeros((self.M, self.N), dtype=self.c_dtype) + copy_func = execution_engine.lookup("gpu_copy_" + self.c_type) + copy_func(get_packed_arg([memref_to_ctype(C_gpu), numpy_to_ctype(C_host_copy)])) + + C_host_ref = self._reference_solution + C_host = C_host_copy.astype(np.float32) + if verbose > 1: + print("Reference solution:") + print(C_host_ref) + print("Computed solution:") + print(C_host) + success = np.allclose(C_host, C_host_ref) + + if verbose: + if success: + print("PASSED") + else: + print("FAILED Result mismatch!") + return success + + def get_complexity(self) -> tuple[int, int, int]: + M, N, K = self.M, self.N, self.K + flop_count = 2 * M * N * K + if self.has_bias: + flop_count += M * N + if self.has_relu: + flop_count += M * N + nbytes_ab = np.dtype(self.ab_dtype).itemsize + nbytes_c = np.dtype(self.c_dtype).itemsize + memory_reads = (M * K + K * N) * nbytes_ab # read A and B + memory_writes = M * N * nbytes_c # write C + return (flop_count, memory_reads, memory_writes) + + def payload_module(self) -> ir.Module: + mod = generate_matmul_payload( + func_name=self.payload_function_name, + M=self.M, + N=self.N, + K=self.K, + ab_type_str=self.ab_type, + c_type_str=self.c_type, + has_bias=self.has_bias, + has_relu=self.has_relu, + ) + return mod + + def schedule_module( + self, dump_kernel: str = None, parameters: Optional[dict] = None + ) -> ir.Module: + return get_schedule_module( + has_bias=self.has_bias, + has_relu=self.has_relu, + dump_kernel=dump_kernel, + params=parameters, + ) + + +def parse_cli(): + parser = argparse.ArgumentParser( + description="Matrix Multiplication using MLIR", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--sizes", + type=int, + nargs=3, + default=[4096, 4096, 4096], + help="M,N,K matrix sizes (A=MxK, B=KxN, C=MxN).", + ) + parser.add_argument( + "--wg-tile", + type=int, + nargs=2, + default=[256, 256], + help="Workgroup tile size M,N.", + ) + parser.add_argument( + "--sg-tile", + type=int, + nargs=2, + default=[32, 32], + help="Subgroup tile size M,N.", + ) + parser.add_argument( + "--k-tile", + type=int, + default=32, + help="Inner reduction dimension tile size K.", + ) + parser.add_argument( + "--load-tile-a", + type=int, + nargs=2, + default=[32, 16], + help="Tile size for loading A matrix for DPAS op.", + ) + parser.add_argument( + "--load-tile-b", + type=int, + nargs=2, + default=[32, 16], + help="Tile size for loading B matrix for DPAS op.", + ) + parser.add_argument( + "--prefetch-tile-a", + type=int, + nargs=2, + default=[8, 32], + help="Tile size for cooperative prefetching of subgroup A matrix", + ) + parser.add_argument( + "--prefetch-tile-b", + type=int, + nargs=2, + default=[8, 16], + help="Tile size for cooperative prefetching of subgroup B matrix", + ) + parser.add_argument( + "--nb-prefetch", + type=int, + default=1, + help="Number of initial prefetches.", + ) + parser.add_argument( + "--nruns", + type=int, + default=1000, + help="Number of runs to average the execution time.", + ) + parser.add_argument( + "--nwarmup", + type=int, + default=20, + help="Number of warm-up iterations before benchmarking.", + ) + parser.add_argument( + "--relu", + action="store_true", + help="Add relu op after the matrix multiplication (and bias if any).", + ) + parser.add_argument( + "--check-result", + action="store_true", + help="Check the result of the matrix multiplication.", + ) + parser.add_argument( + "--dump-kernel", + type=str, + choices=[ + "initial", + "tiled", + "vectorized", + "bufferized", + "xegpu-initial", + "xegpu-wg", + "xegpu-sg", + "xegpu-inst", + "final", + ], + help="Dump kernel IR at different stages of lowering.", + ) + parser.add_argument( + "--dump-schedule", + action="store_true", + help="Dump transform schedule.", + ) + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_cli() + + params = { + "auto_wg_d0": args.wg_tile[0], + "auto_wg_d1": args.wg_tile[1], + "auto_sg_d0": args.sg_tile[0], + "auto_sg_d1": args.sg_tile[1], + "auto_k": args.k_tile, + "auto_load_a_d0": args.load_tile_a[0], + "auto_load_a_d1": args.load_tile_a[1], + "auto_load_b_d0": args.load_tile_b[0], + "auto_load_b_d1": args.load_tile_b[1], + "auto_prefetch_a_d0": args.prefetch_tile_a[0], + "auto_prefetch_a_d1": args.prefetch_tile_a[1], + "auto_prefetch_b_d0": args.prefetch_tile_b[0], + "auto_prefetch_b_d1": args.prefetch_tile_b[1], + "auto_nb_prefetch": args.nb_prefetch, + } + + M, N, K = args.sizes + ab_type = "f16" + c_type = "f32" + + with ir.Context(), ir.Location.unknown(): + wload = XeGPUMatMul( + M=M, + N=N, + K=K, + ab_type=ab_type, + c_type=c_type, + has_bias=False, + has_relu=args.relu, + ) + + if args.dump_kernel or args.dump_schedule: + lower_payload( + wload, + dump_kernel=args.dump_kernel, + dump_schedule=args.dump_schedule, + schedule_parameters=params, + ) + else: + times = benchmark( + wload, + nruns=args.nruns, + nwarmup=args.nwarmup, + check_correctness=args.check_result, + schedule_parameters=params, + verbose=1, + ) + times *= 1e6 # convert to microseconds + elapsed = np.mean(times) + flop_count = wload.get_complexity()[0] + gflops = flop_count / (elapsed * 1e-6) / 1e9 + + def list2str(a): + return ",".join(map(str, a)) + + parts = [ + f"sizes={list2str(args.sizes)}", + f"dt={ab_type},{c_type}", + f"wg-tile={list2str(args.wg_tile)}", + f"sg-tile={list2str(args.sg_tile)}", + f"k-tile={args.k_tile}", + f"load-a-tile={list2str(args.load_tile_a)}", + f"load-b-tile={list2str(args.load_tile_b)}", + f"pf-a-tile={list2str(args.prefetch_tile_a)}", + f"pf-b-tile={list2str(args.prefetch_tile_b)}", + f"time(us): {elapsed:.2f}", + f"GFLOPS: {gflops:.2f}", + ] + print(" ".join(parts)) diff --git a/python/examples/xegpu_matmul/mlir_utils.py b/python/examples/xegpu_matmul/mlir_utils.py new file mode 100644 index 0000000..2b5a4ba --- /dev/null +++ b/python/examples/xegpu_matmul/mlir_utils.py @@ -0,0 +1,29 @@ +from mlir import ir +from mlir.dialects import transform +from mlir.dialects.transform import structured +import os + + +def apply_registered_pass(*args, **kwargs): + return transform.apply_registered_pass(transform.AnyOpType.get(), *args, **kwargs) + + +def match(*args, **kwargs): + return structured.structured_match(transform.AnyOpType.get(), *args, **kwargs) + + +def canonicalize(op): + with ir.InsertionPoint(transform.apply_patterns(op).patterns): + transform.apply_patterns_canonicalization() + + +def get_mlir_library_path(): + pkg_path = ir.__file__ + if "python_packages" in pkg_path: + # looks like a local mlir install + path = pkg_path.split("python_packages")[0] + os.sep + "lib" + else: + # maybe installed in python path + path = os.path.split(pkg_path)[0] + os.sep + "_mlir_libs" + assert os.path.isdir(path) + return path diff --git a/python/examples/xegpu_matmul/payload.py b/python/examples/xegpu_matmul/payload.py new file mode 100644 index 0000000..0cf3a45 --- /dev/null +++ b/python/examples/xegpu_matmul/payload.py @@ -0,0 +1,124 @@ +from mlir import ir +from mlir.dialects import func, linalg, gpu, bufferization, arith, tensor + + +def emit_gpu_alloc(suffix: str, element_type: ir.Type, rank: int = 2): + dyn = ir.ShapedType.get_dynamic_size() + memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type) + index_t = ir.IndexType.get() + i32_t = ir.IntegerType.get_signless(32) + inputs = rank * (i32_t,) + + @func.func(*inputs, name="gpu_alloc_" + suffix) + def alloc_func(*shape): + dims = [arith.index_cast(index_t, a) for a in shape] + alloc = gpu.alloc(memref_dyn_t, None, [], dims, []) + return alloc + + alloc_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + + +def emit_gpu_dealloc(suffix: str, element_type: ir.Type, rank: int = 2): + dyn = ir.ShapedType.get_dynamic_size() + memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type) + + @func.func(memref_dyn_t, name="gpu_dealloc_" + suffix) + def dealloc_func(memref): + gpu.dealloc(None, [], memref) + + dealloc_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + + +def emit_gpu_copy(suffix: str, element_type: ir.Type, rank: int = 2): + """Emit GPU copy function.""" + dyn = ir.ShapedType.get_dynamic_size() + memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type) + + @func.func(memref_dyn_t, memref_dyn_t, name="gpu_copy_" + suffix) + def copy_func(src, dst): + gpu.memcpy(None, [], dst, src) + + copy_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + + +def emit_gpu_util_funcs(element_type: ir.Type): + """Emit GPU utility functions for allocation, deallocation and copy.""" + suffix = { + ir.F16Type.get(): "f16", + ir.F32Type.get(): "f32", + }[element_type] + emit_gpu_alloc(suffix, element_type) + emit_gpu_dealloc(suffix, element_type) + emit_gpu_copy(suffix, element_type) + + +def generate_matmul_payload( + func_name: str, + M: int, + N: int, + K: int, + ab_type_str: str, + c_type_str: str, + has_bias: bool, + has_relu: bool, +) -> ir.Module: + """Generate payload function module.""" + get_ir_dtype = { + "f16": ir.F16Type.get(), + "f32": ir.F32Type.get(), + } + ab_type = get_ir_dtype[ab_type_str] + c_type = get_ir_dtype[c_type_str] + tensor_a_t = ir.RankedTensorType.get((M, K), ab_type) + tensor_b_t = ir.RankedTensorType.get((K, N), ab_type) + tensor_c_t = ir.RankedTensorType.get((M, N), c_type) + tensor_bias_t = ir.RankedTensorType.get((N,), c_type) + memref_a_t = ir.MemRefType.get((M, K), ab_type) + memref_b_t = ir.MemRefType.get((K, N), ab_type) + memref_c_t = ir.MemRefType.get((M, N), c_type) + memref_bias_t = ir.MemRefType.get((N,), c_type) + mod = ir.Module.create() + with ir.InsertionPoint(mod.body): + fargs = [memref_a_t, memref_b_t] + if has_bias: + fargs.append(memref_bias_t) + fargs.append(memref_c_t) + + @func.func(*fargs, name=func_name) + def payload(*args): + A = args[0] + B = args[1] + C = args[-1] + a_tensor = bufferization.to_tensor(tensor_a_t, A, restrict=True) + b_tensor = bufferization.to_tensor(tensor_b_t, B, restrict=True) + c_tensor = bufferization.to_tensor( + tensor_c_t, C, restrict=True, writable=True + ) + + mmul = linalg.matmul(a_tensor, b_tensor, outs=[c_tensor]) + terminal = mmul + if has_bias: + bias = args[2] + bias_tensor = bufferization.to_tensor( + tensor_bias_t, bias, restrict=True, writable=True + ) + empty = tensor.empty((M, N), c_type) + bcast = linalg.broadcast(bias_tensor, outs=[empty], dimensions=[0]) + terminal = linalg.add(bcast, terminal, outs=[empty]) + if has_relu: + zero = arith.constant(c_type, 0.0) + empty = tensor.empty((M, N), c_type) + zero_tensor = linalg.fill(zero, outs=[empty]) + terminal = linalg.max(terminal, zero_tensor, outs=[empty]) + + bufferization.materialize_in_destination( + None, terminal, C, restrict=True, writable=True + ) + + payload.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + + emit_gpu_util_funcs(ab_type) + if c_type != ab_type: + emit_gpu_util_funcs(c_type) + + return mod diff --git a/python/examples/xegpu_matmul/runner.py b/python/examples/xegpu_matmul/runner.py new file mode 100644 index 0000000..f9b0bbd --- /dev/null +++ b/python/examples/xegpu_matmul/runner.py @@ -0,0 +1,177 @@ +import numpy as np +import ctypes +import os +from typing import Optional + +from mlir.dialects import func, arith, scf, memref +from mlir.execution_engine import ExecutionEngine +from mlir import ir +from mlir.runtime.np_to_memref import get_ranked_memref_descriptor + +from lighthouse.utils import get_packed_arg +from mlir_utils import get_mlir_library_path + + +def get_engine(payload_module: ir.Module, opt_level: int = 3) -> ExecutionEngine: + lib_dir = get_mlir_library_path() + libs = [ + "libmlir_levelzero_runtime.so", + "libmlir_runner_utils.so", + "libmlir_c_runner_utils.so", + ] + libs = [os.path.join(lib_dir, lib) for lib in libs] + execution_engine = ExecutionEngine( + payload_module, opt_level=opt_level, shared_libs=libs + ) + execution_engine.initialize() + return execution_engine + + +def apply_transform_schedule( + payload_module: ir.Module, + schedule_module: ir.Module, + dump_kernel: Optional[str] = None, + dump_schedule: bool = False, +): + if not dump_kernel or dump_kernel != "initial": + # apply schedule on payload module + named_seq = schedule_module.body.operations[0] + named_seq.apply(payload_module) + if dump_kernel: + print(payload_module) + if dump_schedule: + print(schedule_module) + + +def lower_payload( + workload: object, + dump_kernel: Optional[str] = None, + dump_schedule: bool = False, + schedule_parameters: Optional[dict] = None, +) -> ir.Module: + payload_module = workload.payload_module() + schedule_module = workload.schedule_module( + dump_kernel=dump_kernel, parameters=schedule_parameters + ) + apply_transform_schedule( + payload_module, + schedule_module, + dump_kernel=dump_kernel, + dump_schedule=dump_schedule, + ) + return payload_module + + +def execute( + workload: object, + check_correctness: bool = True, + schedule_parameters: Optional[dict] = None, + verbose: int = 0, +): + # lower payload with schedule + payload_module = lower_payload(workload, schedule_parameters=schedule_parameters) + # get execution engine + engine = get_engine(payload_module, requirements=workload.requirements()) + + with workload.allocate_inputs(execution_engine=engine) as inputs: + # prepare function arguments + pointers = [ctypes.pointer(ctypes.pointer(m)) for m in inputs] + packed_args = get_packed_arg(pointers) + + # handle to payload function + payload_func = engine.lookup(workload.payload_function_name) + + # call + payload_func(packed_args) + + if check_correctness: + workload.check_correctness(execution_engine=engine, verbose=verbose) + + +def benchmark( + workload: object, + nruns: int = 100, + nwarmup: int = 10, + schedule_parameters: Optional[dict] = None, + check_correctness: bool = True, + verbose: int = 0, +) -> np.ndarray: + # get original payload module + payload_module = workload.payload_module() + + # find payload function + payload_func = None + for op in payload_module.operation.regions[0].blocks[0]: + if ( + isinstance(op, func.FuncOp) + and op.name.value == workload.payload_function_name + ): + payload_func = op + break + assert payload_func is not None, "Could not find payload function" + payload_arguments = payload_func.type.inputs + + # emit benchmark function that calls payload and times it + with ir.InsertionPoint(payload_module.body): + # define rtclock function + f64_t = ir.F64Type.get() + func.FuncOp("rtclock", ((), (f64_t,)), visibility="private") + # emit benchmark function + time_memref_t = ir.MemRefType.get((nruns,), f64_t) + args = payload_arguments + [time_memref_t] + + @func.func(*args) + def benchmark(*args): + index_t = ir.IndexType.get() + zero = arith.constant(index_t, 0) + one = arith.constant(index_t, 1) + nwarmup_cst = arith.constant(index_t, nwarmup) + for i in scf.for_(zero, nwarmup_cst, one): + # FIXME(upstream): func.call is broken for this use case? + func.CallOp(payload_func, list(args[: len(payload_arguments)])) + scf.yield_(()) + nruns_cst = arith.constant(index_t, nruns) + for i in scf.for_(zero, nruns_cst, one): + tic = func.call((f64_t,), "rtclock", ()) + func.CallOp(payload_func, list(args[: len(payload_arguments)])) + toc = func.call((f64_t,), "rtclock", ()) + time = arith.subf(toc, tic) + memref.store(time, args[-1], [i]) + scf.yield_(()) + + benchmark.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + + # lower + apply_transform_schedule( + payload_module, + workload.schedule_module(parameters=schedule_parameters), + ) + # get execution engine, rtclock requires mlir_c_runner + engine = get_engine(payload_module) + + with workload.allocate_inputs(execution_engine=engine) as inputs: + pointers = [ctypes.pointer(ctypes.pointer(m)) for m in inputs] + if check_correctness: + # call payload once to verify correctness + # prepare function arguments + packed_args = get_packed_arg(pointers) + + payload_func = engine.lookup(workload.payload_function_name) + payload_func(packed_args) + success = workload.check_correctness( + execution_engine=engine, verbose=verbose + ) + if not success: + raise ValueError("Benchmark verification failed.") + + # allocate buffer for timings and prepare arguments + time_array = np.zeros((nruns,), dtype=np.float64) + time_memref = get_ranked_memref_descriptor(time_array) + time_pointer = ctypes.pointer(ctypes.pointer(time_memref)) + packed_args_with_time = get_packed_arg(pointers + [time_pointer]) + + # call benchmark function + benchmark_func = engine.lookup("benchmark") + benchmark_func(packed_args_with_time) + + return time_array diff --git a/python/examples/xegpu_matmul/schedule.py b/python/examples/xegpu_matmul/schedule.py new file mode 100644 index 0000000..355177a --- /dev/null +++ b/python/examples/xegpu_matmul/schedule.py @@ -0,0 +1,433 @@ +from mlir import ir +from mlir.dialects.transform import loop +from mlir.dialects.transform import bufferization +from mlir.dialects.transform import xegpu +from mlir.dialects.bufferization import LayoutMapOption +from mlir.dialects import transform +from mlir.dialects.transform import structured +from mlir_utils import apply_registered_pass, match, canonicalize +from typing import Optional + + +# hardware constraints +dpas_tile = [8, 16, 16] +prefetch_inst_data = [8, 16] +nb_workitems = 16 # workitems in subgroup + + +def get_schedule_module( + has_bias: bool = False, + has_relu: bool = False, + dump_kernel: str = "", + params: Optional[dict] = None, +) -> ir.Module: + """Generate transform schedule module.""" + mod = ir.Module.create() + mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() + with ir.InsertionPoint(mod.body): + named_sequence = transform.named_sequence( + "__transform_main", + [transform.AnyOpType.get()], # input types + [], # output types + arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], + ) + with ir.InsertionPoint(named_sequence.body): + # match the payload module + anytype = transform.AnyOpType.get() + func = match(named_sequence.bodyTarget, ops={"func.func"}) + payload_mod = transform.get_parent_op( + anytype, + func, + op_name="builtin.module", + deduplicate=True, + ) + xegpu_matmul_transform_schedule( + payload_mod, + has_bias=has_bias, + has_relu=has_relu, + dump_kernel=dump_kernel, + params=params, + ) + + return mod + + +def xegpu_matmul_transform_schedule( + mod: ir.Value, + has_bias: bool = False, + has_relu: bool = False, + dump_kernel: str = "", + params: Optional[dict] = None, +): + """Transform schedule for matmul-like payload.""" + mod, interrupted = bundle_xepu_matmul_schedule( + mod, + has_bias=has_bias, + has_relu=has_relu, + dump_kernel=dump_kernel, + params=params, + ) + if interrupted: + transform.yield_() + return + + mod, interrupted = bundle_xegpu_to_binary( + mod, + dump_kernel=dump_kernel, + ) + transform.yield_() + + +def bundle_xepu_matmul_schedule( + mod, + has_bias: bool = False, + has_relu: bool = False, + dump_kernel: str = "", + params: Optional[dict] = None, +): + """Schedule for lowering matmul-like payload to xegpu wg level.""" + if params is None: + raise ValueError("Schedule parameters must be provided.") + + # tunable parameters + wg_tile = [params["auto_wg_d0"], params["auto_wg_d1"]] + sg_tile = [params["auto_sg_d0"], params["auto_sg_d1"]] + k_tile = params["auto_k"] + + load_tile_a = [params["auto_load_a_d0"], params["auto_load_a_d1"]] + load_tile_b = [params["auto_load_b_d0"], params["auto_load_b_d1"]] + + prefetch_tile_a = [params["auto_prefetch_a_d0"], params["auto_prefetch_a_d1"]] + prefetch_tile_b = [params["auto_prefetch_b_d0"], params["auto_prefetch_b_d1"]] + nb_prefetch = params["auto_nb_prefetch"] + + # derived parameters + sg_layout = [wg_tile[0] // sg_tile[0], wg_tile[1] // sg_tile[1]] + # number of threads collapsed to 1d layout + nb_threads = sg_layout[0] * sg_layout[1] * nb_workitems + prefetch_layout_a = [ + wg_tile[0] // prefetch_tile_a[0], + k_tile // prefetch_tile_a[1], + ] + prefetch_layout_b = [ + k_tile // prefetch_tile_b[0], + wg_tile[1] // prefetch_tile_b[1], + ] + + # matmul matrix shapes + sg_tile_a = [sg_tile[0], k_tile] + sg_tile_b = [k_tile, sg_tile[1]] + + if dump_kernel == "initial": + return mod, True + + anytype = transform.AnyOpType.get() + anyvalue = transform.AnyValueType.get() + + # match the payload function + anchor = match(mod, ops={"linalg.matmul"}) + func = transform.get_parent_op( + anytype, + anchor, + op_name="func.func", + deduplicate=True, + ) + + dpas_shape_a = [dpas_tile[0], dpas_tile[2]] + dpas_shape_b = [dpas_tile[2], dpas_tile[1]] + dpas_shape_c = [dpas_tile[0], dpas_tile[1]] + + # wg tiling + if has_relu: + terminal = match(mod, ops={"linalg.max"}) + elif has_bias: + terminal = match(mod, ops={"linalg.add"}) + else: + terminal = match(mod, ops={"linalg.matmul"}) + # FIXME use structured.structured_fuse + structured.FuseOp(terminal, tile_sizes=wg_tile, use_forall=True) + transform.apply_cse(mod) + canonicalize(mod) + + # k loop tiling + wg_matmul = match(mod, ops={"linalg.matmul"}) + # FIXME use structured.structured_tile_using_for + wgk_matmul, k_loop = structured.TileUsingForOp( + wg_matmul, sizes=[0, 0, k_tile] + ).results + + transform.apply_cse(func) + canonicalize(func) + + if dump_kernel == "tiled": + return mod, True + + # vectorize + # FIXME use structured.structured_vectorize_children_and_apply_patterns + func = structured.VectorizeChildrenAndApplyPatternsOp( + func, + fold_type_extensions_into_contract=True, + ).result + + # hoist loop invariant vector read/store ops + k_loop = match(func, ops={"scf.for"}) + loop.HoistLoopInvariantSubsetsOp(k_loop) + + transform.apply_cse(func) + canonicalize(func) + + if dump_kernel == "vectorized": + return mod, True + + # bufferize + + # eliminate empty tensors to avoid emitting extra copy ops + mod = apply_registered_pass(mod, "eliminate-empty-tensors") + identity_layout = LayoutMapOption.IdentityLayoutMap + mod = bufferization.OneShotBufferizeOp( + mod, + allow_return_allocs_from_loops=True, + bufferize_function_boundaries=True, + function_boundary_type_conversion=identity_layout, + ).result + # fold memref.subviews into vector.transfer_read/write ops + mod = apply_registered_pass(mod, "fold-memref-alias-ops") + transform.apply_cse(mod) + canonicalize(mod) + + if dump_kernel == "bufferized": + return mod, True + + # convert forall to parallel + wg_loop = match(mod, ops={"scf.forall"}) + wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) + func = transform.get_parent_op(anytype, wg_loop) + + # convert to scf.parallel to gpu.launch + func = apply_registered_pass(func, "gpu-map-parallel-loops") + func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") + func = apply_registered_pass(func, "lower-affine") + transform.apply_cse(func) + canonicalize(func) + + # set correct number of gpu threads + launch_op = match(func, ops={"gpu.launch"}) + xegpu.set_gpu_launch_threads(launch_op, threads=[nb_threads, 1, 1]) + + # outline gpu func + func = apply_registered_pass(func, "lower-affine") + canonicalize(func) + func = apply_registered_pass(func, "gpu-launch-sink-index-computations") + mod = apply_registered_pass(mod, "gpu-kernel-outlining") + transform.apply_cse(mod) + + # set xevm target + mod = apply_registered_pass( + mod, + "xevm-attach-target", + options={"O": "3", "chip": "bmg"}, + ) + + # convert vector to xegpu + gpu_mod = match(mod, ops={"gpu.module"}) + gpu_func = match(gpu_mod, ops={"gpu.func"}) + gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") + transform.apply_cse(gpu_func) + + if dump_kernel == "xegpu-initial": + return mod, True + + # add layouts to DPAS op operands + k_loop = match(gpu_func, ops={"scf.for"}) + dpas_op = match(k_loop, ops={"xegpu.dpas"}) + tile_a = transform.get_operand(anyvalue, dpas_op, [0]) + tile_b = transform.get_operand(anyvalue, dpas_op, [1]) + tile_c = transform.get_operand(anyvalue, dpas_op, [2]) + + def convert_layout(value, input, target): + xegpu.convert_layout( + value, + input_sg_layout=input["sg_layout"], + input_sg_data=input["sg_data"], + input_inst_data=input["inst_data"], + target_sg_layout=target["sg_layout"], + target_sg_data=target["sg_data"], + target_inst_data=target["inst_data"], + ) + + # insert prefetch ops for DPAS A and B tiles + desc_prefetch_a = xegpu.insert_prefetch( + tile_a, + nb_prefetch=nb_prefetch, + ) + xegpu.set_desc_layout( + desc_prefetch_a, + sg_layout=prefetch_layout_a, + sg_data=prefetch_tile_a, + inst_data=prefetch_inst_data, + ) + desc_prefetch_b = xegpu.insert_prefetch( + tile_b, + nb_prefetch=nb_prefetch, + ) + xegpu.set_desc_layout( + desc_prefetch_b, + sg_layout=prefetch_layout_b, + sg_data=prefetch_tile_b, + inst_data=prefetch_inst_data, + ) + + # A tile load layout + layout_load_a = { + "sg_layout": sg_layout, + "sg_data": sg_tile_a, + "inst_data": load_tile_a, + } + desc_op_a = xegpu.get_desc_op(tile_a) + desc_op_a = xegpu.set_desc_layout( + target=desc_op_a, + **layout_load_a, + ) + # A tile dpas layout + layout_dpas_a = layout_load_a.copy() + layout_dpas_a["inst_data"] = dpas_shape_a + convert_layout(tile_a, layout_load_a, layout_dpas_a) + + # B tile load layout + layout_load_b = { + "sg_layout": sg_layout, + "sg_data": sg_tile_b, + "inst_data": load_tile_b, + } + desc_op_b = xegpu.get_desc_op(tile_b) + desc_op_b = xegpu.set_desc_layout( + target=desc_op_b, + **layout_load_b, + ) + # B tile dpas layout + layout_dpas_b = layout_load_b.copy() + layout_dpas_b["inst_data"] = dpas_shape_b + convert_layout(tile_b, layout_load_b, layout_dpas_b) + + # C tile layout + output_layout = { + "sg_layout": sg_layout, + "sg_data": sg_tile, + "inst_data": dpas_shape_c, + } + desc_op_c = xegpu.get_desc_op(tile_c) + desc_op_c = xegpu.set_desc_layout(desc_op_c, **output_layout) + # C tile dpas layout + xegpu.set_op_layout_attr(dpas_op, result=True, index=0, **output_layout) + + if has_relu: + # for post ops we need to add C layout manually + max_op = match(gpu_func, ops={"arith.maximumf"}) + xegpu.set_op_layout_attr(max_op, result=True, index=0, **output_layout) + # find zero constant buffer and annotate it + const_buffer = transform.get_producer_of_operand(anytype, max_op, 1) + xegpu.set_op_layout_attr(const_buffer, result=True, index=0, **output_layout) + if has_bias: + # for post ops we need to add C layout manually + add_op = match(gpu_func, ops={"arith.addf"}) + xegpu.set_op_layout_attr(add_op, result=True, index=0, **output_layout) + + # annotate broadcast op operands + bcast_op = transform.get_producer_of_operand(anytype, add_op, 0) + xegpu.set_op_layout_attr(bcast_op, result=True, index=0, **output_layout) + bcast_load = transform.get_producer_of_operand(anytype, bcast_op, 0) + xegpu.set_op_layout_attr( + bcast_load, result=True, index=0, **output_layout, slice_dims=[0] + ) + output_layout_dim1 = { + "sg_layout": [sg_layout[1]], + "sg_data": [sg_tile[1]], + "inst_data": [dpas_shape_c[1]], + } + offset = transform.get_producer_of_operand(anytype, bcast_load, 1) + xegpu.set_op_layout_attr(offset, result=True, index=0, **output_layout_dim1) + aux1 = transform.get_producer_of_operand(anytype, offset, 0) + xegpu.set_op_layout_attr(aux1, result=True, index=0, **output_layout_dim1) + aux2 = transform.get_producer_of_operand(anytype, offset, 1) + xegpu.set_op_layout_attr(aux2, result=True, index=0, **output_layout_dim1) + mask = transform.get_producer_of_operand(anytype, bcast_load, 2) + xegpu.set_op_layout_attr(mask, result=True, index=0, **output_layout_dim1) + raise NotImplementedError("Bias layout propagation is not supported.") + transform.apply_cse(gpu_func) + canonicalize(gpu_func) + + # hoist desc ops out of reduction loop + transform.apply_licm(k_loop) + + canonicalize(gpu_func) + transform.apply_cse(gpu_func) + + if dump_kernel == "xegpu-wg": + return mod, True + + return mod, False + + +def bundle_xegpu_to_binary(mod, dump_kernel: str = ""): + """Schedule for lowering xegpu wg level to binary.""" + # This schedule corresponds to upstream MLIR XeVM lowering pipeline + # and is payload independent. + + # TODO applying gpu-lower-to-xevm-pipeline pass affects performance + # mod = apply_registered_pass( + # mod, "gpu-lower-to-xevm-pipeline", options={"xegpu-op-level": "workgroup"} + # ) + + gpu_mod = match(mod, ops={"gpu.module"}) + # xegpu distribution + gpu_func = match(gpu_mod, ops={"gpu.func"}) + gpu_func = apply_registered_pass(gpu_func, "xegpu-wg-to-sg-distribute") + transform.apply_cse(gpu_func) + + if dump_kernel == "xegpu-sg": + return mod, True + + gpu_func = apply_registered_pass(gpu_func, "lower-affine") + transform.apply_cse(gpu_func) + gpu_func = apply_registered_pass(gpu_func, "xegpu-blocking") + canonicalize(gpu_func) + transform.apply_cse(gpu_func) + + if dump_kernel == "xegpu-inst": + return mod, True + + gpu_func = apply_registered_pass(gpu_func, "xegpu-propagate-layout") + gpu_mod = apply_registered_pass(gpu_mod, "xegpu-subgroup-distribute") + canonicalize(gpu_mod) + transform.apply_cse(gpu_mod) + gpu_mod = apply_registered_pass(gpu_mod, "loop-invariant-code-motion") + transform.apply_cse(gpu_mod) + gpu_mod = apply_registered_pass(gpu_mod, "xegpu-vector-linearize") + gpu_mod = apply_registered_pass(gpu_mod, "convert-xegpu-to-xevm") + gpu_mod = apply_registered_pass( + gpu_mod, "convert-gpu-to-llvm-spv", options={"use-64bit-index": "true"} + ) + gpu_mod = apply_registered_pass(gpu_mod, "convert-xevm-to-llvm") + transform.apply_cse(gpu_mod) + + func = match(mod, ops={"func.func"}) + func = apply_registered_pass(func, "gpu-async-region") + + mod = apply_registered_pass(mod, "reconcile-unrealized-casts") + mod = apply_registered_pass(mod, "convert-vector-to-scf") + mod = apply_registered_pass(mod, "convert-scf-to-cf") + mod = apply_registered_pass(mod, "expand-strided-metadata") + mod = apply_registered_pass(mod, "finalize-memref-to-llvm") + mod = apply_registered_pass(mod, "convert-cf-to-llvm") + mod = apply_registered_pass(mod, "convert-vector-to-llvm") + mod = apply_registered_pass(mod, "convert-arith-to-llvm") + mod = apply_registered_pass(mod, "convert-index-to-llvm") + mod = apply_registered_pass(mod, "convert-func-to-llvm") + mod = apply_registered_pass(mod, "convert-math-to-llvm") + mod = apply_registered_pass(mod, "gpu-to-llvm") + mod = apply_registered_pass(mod, "lower-affine") + mod = apply_registered_pass(mod, "reconcile-unrealized-casts") + transform.apply_cse(mod) + mod = apply_registered_pass(mod, "gpu-module-to-binary") + + return mod, False