From 8e04420cd57a0b41ddb77a77c4049bbad2804d73 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Thu, 13 Nov 2025 20:31:01 +0200 Subject: [PATCH 01/16] add xegpu matmul example --- python/examples/xegpu_matmul/README.md | 84 +++ .../examples/xegpu_matmul/execution_engine.py | 195 +++++++ python/examples/xegpu_matmul/matmul.py | 401 +++++++++++++++ python/examples/xegpu_matmul/mlir_utils.py | 33 ++ .../xegpu_matmul/payload_generator.py | 134 +++++ python/examples/xegpu_matmul/schedule.py | 486 ++++++++++++++++++ 6 files changed, 1333 insertions(+) create mode 100644 python/examples/xegpu_matmul/README.md create mode 100644 python/examples/xegpu_matmul/execution_engine.py create mode 100644 python/examples/xegpu_matmul/matmul.py create mode 100644 python/examples/xegpu_matmul/mlir_utils.py create mode 100644 python/examples/xegpu_matmul/payload_generator.py create mode 100644 python/examples/xegpu_matmul/schedule.py diff --git a/python/examples/xegpu_matmul/README.md b/python/examples/xegpu_matmul/README.md new file mode 100644 index 0000000..0e4c96e --- /dev/null +++ b/python/examples/xegpu_matmul/README.md @@ -0,0 +1,84 @@ +# XeGPU matrix multiplication benchmark + +## Installation + +### 1. GPU Drivers and Level Zero + +Install Intel GPU drivers and Level Zero runtime on your system. + +### 2. Compile LLVM with Intel GPU support + +To use Lighthouse with Intel GPUs, LLVM must be built with LevelZero runtime. + +Set up a Python environment and install Python packages: + +```bash +pip install pybind11 nanobind PyYAML numpy +``` + +Set `LLVM_INSTALL_DIR` and use the below script to checkout and compile LLVM locally. 
+ +```bash +export LLVM_INSTALL_DIR=<...> +LLVM_VERSION=83765f435d1c +git clone https://github.com/llvm/llvm-project.git +cd llvm-project +git checkout $LLVM_VERSION +mkdir -p build +cd build + +cmake ../llvm -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_BUILD_EXAMPLES=OFF \ + -DLLVM_TARGETS_TO_BUILD="host" \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_ENABLE_RTTI=ON \ + -DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD="SPIRV" \ + -DLLVM_INSTALL_GTEST=ON \ + -DMLIR_ENABLE_LEVELZERO_RUNNER=1 \ + -DMLIR_ENABLE_BINDINGS_PYTHON=1 \ + -DPython3_EXECUTABLE=$(which python3) \ + -DLLVM_INSTALL_UTILS=ON \ + -DCMAKE_INSTALL_PREFIX=${LLVM_INSTALL_DIR} +cmake --build . +cmake --install . +``` + +If cmake cannot find LevelZero, set the environment variable `LEVEL_ZERO_DIR=<path-to-level-zero-install>`. + +### 3. Install Lighthouse + +Install Lighthouse as instructed in the main [README](../../../README.md). + +Override the default LLVM package by setting `PYTHONPATH` to the local LLVM Python bindings: + +```bash +export PYTHONPATH=${LLVM_INSTALL_DIR}/python_packages/mlir_core +``` + +## Usage + +Run the default 4k (float16, float16) -> float32 matrix multiplication benchmark with correctness test: + +```bash +python matmul.py --check-result +``` + +Set a different M, N, K problem size: + +```bash +python matmul.py --sizes 1024 2048 4096 ... +``` + +Run with ReLU post-op: + +```bash +python matmul.py --relu ... 
+``` + +See all command line arguments: + +```bash +python matmul.py --help +``` diff --git a/python/examples/xegpu_matmul/execution_engine.py b/python/examples/xegpu_matmul/execution_engine.py new file mode 100644 index 0000000..562f659 --- /dev/null +++ b/python/examples/xegpu_matmul/execution_engine.py @@ -0,0 +1,195 @@ +import numpy as np +import ctypes +import os +from typing import Optional + +from mlir.dialects.transform import interpreter as transform_interpreter +from mlir.dialects import func, arith, scf, memref +from mlir.execution_engine import ExecutionEngine +from mlir import ir +from mlir.runtime.np_to_memref import get_ranked_memref_descriptor + +from lighthouse.utils import get_packed_arg +from mlir_utils import get_mlir_library_path + + +def get_engine(payload_module, opt_level=3) -> ExecutionEngine: + context = ir.Context() + location = ir.Location.unknown(context) + lib_dir = get_mlir_library_path() + libs = [ + "libmlir_levelzero_runtime.so", + "libmlir_runner_utils.so", + "libmlir_c_runner_utils.so", + ] + libs = [os.path.join(lib_dir, lib) for lib in libs] + with context, location: + execution_engine = ExecutionEngine( + payload_module, opt_level=opt_level, shared_libs=libs + ) + execution_engine.initialize() + return execution_engine + + +def apply_transform_schedule( + payload_module, + schedule_module, + context, + location, + dump_kernel: Optional[str] = None, + dump_schedule: bool = False, +): + if not dump_kernel or dump_kernel != "initial": + with context, location: + # invoke transform interpreter directly + transform_interpreter.apply_named_sequence( + payload_root=payload_module, + transform_root=schedule_module.body.operations[0], + transform_module=schedule_module, + ) + if dump_kernel: + print(payload_module) + if dump_schedule: + print(schedule_module) + + +def lower_payload( + workload, + dump_kernel: Optional[str] = None, + dump_schedule: bool = False, + schedule_parameters: Optional[dict] = None, +) -> ir.Module: + 
payload_module = workload.payload_module() + schedule_module = workload.schedule_module( + dump_kernel=dump_kernel, parameters=schedule_parameters + ) + apply_transform_schedule( + payload_module, + schedule_module, + workload.context, + workload.location, + dump_kernel=dump_kernel, + dump_schedule=dump_schedule, + ) + return payload_module + + +def execute( + workload, + check_correctness: bool = True, + schedule_parameters: Optional[dict] = None, + verbose: int = 0, +): + # lower payload with schedule + payload_module = lower_payload(workload, schedule_parameters=schedule_parameters) + # get execution engine + engine = get_engine(payload_module, requirements=workload.requirements()) + + with workload.allocate(execution_engine=engine): + # prepare function arguments + inputs = workload.get_input_arrays(execution_engine=engine) + pointers = [ctypes.pointer(ctypes.pointer(m)) for m in inputs] + packed_args = get_packed_arg(pointers) + + # handle to payload function + payload_func = engine.lookup(workload.payload_function_name) + + # call + payload_func(packed_args) + + if check_correctness: + workload.check_correctness(execution_engine=engine, verbose=verbose) + + +def benchmark( + workload, + nruns: int = 100, + nwarmup: int = 10, + schedule_parameters: Optional[dict] = None, + check_correctness: bool = True, + verbose: int = 0, +) -> np.ndarray: + # get original payload module + payload_module = workload.payload_module() + + # find payload function + payload_func = None + for op in payload_module.operation.regions[0].blocks[0]: + if ( + isinstance(op, func.FuncOp) + and str(op.name).strip('"') == workload.payload_function_name + ): + payload_func = op + break + assert payload_func is not None, "Could not find payload function" + payload_arguments = payload_func.type.inputs + + # emit benchmark function that calls payload and times it + with workload.context, workload.location: + with ir.InsertionPoint(payload_module.body): + # define rtclock function + f64_t = 
ir.F64Type.get() + f = func.FuncOp("rtclock", ((), (f64_t,)), visibility="private") + # emit benchmark function + time_memref_t = ir.MemRefType.get((nruns,), f64_t) + args = payload_arguments + [time_memref_t] + f = func.FuncOp("benchmark", (tuple(args), ())) + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): + index_t = ir.IndexType.get() + zero = arith.ConstantOp(index_t, 0) + one = arith.ConstantOp(index_t, 1) + nwarmup_cst = arith.ConstantOp(index_t, nwarmup) + for_op = scf.ForOp(zero, nwarmup_cst, one) + with ir.InsertionPoint(for_op.body): + func.CallOp(payload_func, list(f.arguments[: len(payload_arguments)])) + scf.YieldOp(()) + nruns_cst = arith.ConstantOp(index_t, nruns) + for_op = scf.ForOp(zero, nruns_cst, one) + i = for_op.induction_variable + with ir.InsertionPoint(for_op.body): + tic = func.CallOp((f64_t,), "rtclock", ()).result + func.CallOp(payload_func, list(f.arguments[: len(payload_arguments)])) + toc = func.CallOp((f64_t,), "rtclock", ()).result + time = arith.SubFOp(toc, tic) + memref.StoreOp(time, f.arguments[-1], [i]) + scf.YieldOp(()) + func.ReturnOp(()) + + # lower + apply_transform_schedule( + payload_module, + workload.schedule_module(parameters=schedule_parameters), + workload.context, + workload.location, + ) + # get execution engine, rtclock requires mlir_c_runner + engine = get_engine(payload_module) + + with workload.allocate(execution_engine=engine): + inputs = workload.get_input_arrays(execution_engine=engine) + pointers = [ctypes.pointer(ctypes.pointer(m)) for m in inputs] + if check_correctness: + # call payload once to verify correctness + # prepare function arguments + packed_args = get_packed_arg(pointers) + + payload_func = engine.lookup(workload.payload_function_name) + payload_func(packed_args) + success = workload.check_correctness( + execution_engine=engine, verbose=verbose + ) + if not success: + raise ValueError("Benchmark verification failed.") + + # allocate 
buffer for timings and prepare arguments + time_array = np.zeros((nruns,), dtype=np.float64) + time_memref = get_ranked_memref_descriptor(time_array) + time_pointer = ctypes.pointer(ctypes.pointer(time_memref)) + packed_args_with_time = get_packed_arg(pointers + [time_pointer]) + + # call benchmark function + benchmark_func = engine.lookup("benchmark") + benchmark_func(packed_args_with_time) + + return time_array diff --git a/python/examples/xegpu_matmul/matmul.py b/python/examples/xegpu_matmul/matmul.py new file mode 100644 index 0000000..2a2985b --- /dev/null +++ b/python/examples/xegpu_matmul/matmul.py @@ -0,0 +1,401 @@ +""" +XeGPU matrix multiplication benchmark. +""" + +import numpy as np +from mlir import ir +from mlir.runtime.np_to_memref import ( + get_ranked_memref_descriptor, + make_nd_memref_descriptor, + as_ctype, +) +import ctypes +from contextlib import contextmanager +from functools import cached_property +from lighthouse.utils import get_packed_arg, memref_to_ctype +from schedule import get_schedule_module +from payload_generator import generate_matmul_payload + +from execution_engine import lower_payload, benchmark +import argparse + + +def numpy_to_ctype(arr) -> ctypes._Pointer: + """Convert numpy array to memref and ctypes **void pointer.""" + return memref_to_ctype(get_ranked_memref_descriptor(arr)) + + +class XeGPUMatMul: + """ + Matrix multiplication workload on XeGPU. + + Computes C = A * B for input matrices A (M x K) and B (K x N). + + Optionally adds a ReLU operation on the result C. + Optionally adds a bias term to C (not implemented yet). 
+ """ + + payload_function_name = "payload" + + def __init__( + self, M, N, K, ab_type="f32", c_type="f32", has_bias=False, has_relu=False + ): + self.M = M + self.N = N + self.K = K + self.a_shape = (M, K) + self.b_shape = (K, N) + self.c_shape = (M, N) + self.ab_type = ab_type + self.c_type = c_type + type_str_to_numpy = { + "f16": np.float16, + "f32": np.float32, + } + self.ab_dtype = type_str_to_numpy[ab_type] + self.c_dtype = type_str_to_numpy[c_type] + self.has_bias = has_bias + self.has_relu = has_relu + if has_bias: + raise NotImplementedError("Bias is not implemented yet") + self.context = ir.Context() + self.location = ir.Location.unknown(context=self.context) + # cache allocated memrefs + self.gpu_memrefs = {} + + def _allocate_array( + self, name, shape, dtype_str, execution_engine + ) -> ctypes.Structure: + key = (name, dtype_str) + if key in self.gpu_memrefs: + return self.gpu_memrefs[key] + dtype = { + "f16": np.float16, + "f32": np.float32, + }[dtype_str] + alloc_func = execution_engine.lookup("gpu_alloc_" + dtype_str) + mref = make_nd_memref_descriptor(len(shape), as_ctype(dtype))() + ptr_mref = ctypes.pointer(ctypes.pointer(mref)) + ptr_dims = [ctypes.pointer(ctypes.c_int32(d)) for d in shape] + alloc_func(get_packed_arg([ptr_mref] + ptr_dims)) + self.gpu_memrefs[key] = mref + return mref + + def _allocate_inputs(self, execution_engine): + self._allocate_array("A", self.a_shape, self.ab_type, execution_engine) + self._allocate_array("B", self.b_shape, self.ab_type, execution_engine) + self._allocate_array("C", self.c_shape, self.c_type, execution_engine) + + def _deallocate_all(self, execution_engine): + for (_, dtype_str), mref in self.gpu_memrefs.items(): + dealloc_func = execution_engine.lookup("gpu_dealloc_" + dtype_str) + ptr_mref = ctypes.pointer(ctypes.pointer(mref)) + dealloc_func(get_packed_arg([ptr_mref])) + self.gpu_memrefs = {} + + @contextmanager + def allocate(self, execution_engine): + try: + self._allocate_inputs(execution_engine) 
+ yield None + finally: + self._deallocate_all(execution_engine) + + @cached_property + def _initial_host_arrays(self) -> list[np.ndarray]: + """Generate initial values on host with numpy.""" + + # use integer values to avoid f16/f32 floating point discrepancies + def gen_random(shape, dtype): + # generate values in range [-3, 3] + a = np.round(6 * np.random.random_sample(shape)) - 3 + return a.astype(dtype) + + np.random.seed(2) + A = gen_random((self.M, self.K), self.ab_dtype) + B = gen_random((self.K, self.N), self.ab_dtype) + C = gen_random((self.M, self.N), self.c_dtype) + return A, B, C + + @cached_property + def _reference_solution(self) -> np.ndarray: + """Compute reference solution on host with numpy.""" + A, B, C = self._initial_host_arrays + # use float32 data type for efficiency + f32 = np.float32 + C_ref = A.astype(f32) @ B.astype(f32) + C.astype(f32) + if self.has_relu: + C_ref = np.maximum(C_ref, 0) + if self.has_bias: + raise NotImplementedError("Bias verification not implemented") + return C_ref + + def get_input_arrays(self, execution_engine) -> list[ctypes.Structure]: + A_gpu = self._allocate_array("A", self.a_shape, self.ab_type, execution_engine) + B_gpu = self._allocate_array("B", self.b_shape, self.ab_type, execution_engine) + C_gpu = self._allocate_array("C", self.c_shape, self.c_type, execution_engine) + + A_host, B_host, C_host = self._initial_host_arrays + # copy initial values to device + copy_func_ab = execution_engine.lookup("gpu_copy_" + self.ab_type) + copy_func_c = execution_engine.lookup("gpu_copy_" + self.c_type) + copy_func_ab(get_packed_arg([numpy_to_ctype(A_host), memref_to_ctype(A_gpu)])) + copy_func_ab(get_packed_arg([numpy_to_ctype(B_host), memref_to_ctype(B_gpu)])) + copy_func_c(get_packed_arg([numpy_to_ctype(C_host), memref_to_ctype(C_gpu)])) + + # return memrefs for the payload function + return [A_gpu, B_gpu, C_gpu] + + def check_correctness(self, execution_engine, verbose: int = 0) -> bool: + # copy result from device 
to host + C_gpu = self.gpu_memrefs[("C", self.c_type)] + C_host_copy = np.zeros((self.M, self.N), dtype=self.c_dtype) + copy_func = execution_engine.lookup("gpu_copy_" + self.c_type) + copy_func(get_packed_arg([memref_to_ctype(C_gpu), numpy_to_ctype(C_host_copy)])) + + C_host_ref = self._reference_solution + C_host = C_host_copy.astype(np.float32) + if verbose > 1: + print("Reference solution:") + print(C_host_ref) + print("Computed solution:") + print(C_host) + success = np.allclose(C_host, C_host_ref) + + if verbose: + if success: + print("PASSED") + else: + print("FAILED Result mismatch!") + return success + + def get_complexity(self) -> tuple[int, int, int]: + M, N, K = self.M, self.N, self.K + flop_count = 2 * M * N * K + if self.has_bias: + flop_count += M * N + if self.has_relu: + flop_count += M * N + nbytes_ab = np.dtype(self.ab_dtype).itemsize + nbytes_c = np.dtype(self.c_dtype).itemsize + memory_reads = (M * K + K * N) * nbytes_ab # read A and B + memory_writes = M * N * nbytes_c # write C + return (flop_count, memory_reads, memory_writes) + + def payload_module(self) -> ir.Module: + mod = generate_matmul_payload( + func_name=self.payload_function_name, + M=self.M, + N=self.N, + K=self.K, + ab_type_str=self.ab_type, + c_type_str=self.c_type, + has_bias=self.has_bias, + has_relu=self.has_relu, + context=self.context, + location=self.location, + ) + return mod + + def schedule_module(self, dump_kernel=None, parameters=None) -> ir.Module: + return get_schedule_module( + context=self.context, + location=self.location, + has_bias=self.has_bias, + has_relu=self.has_relu, + dump_kernel=dump_kernel, + params=parameters, + ) + + +def parse_cli(): + parser = argparse.ArgumentParser( + description="Matrix Multiplication using MLIR", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--sizes", + type=int, + nargs=3, + default=[4096, 4096, 4096], + help="M,N,K matrix sizes (A=MxK, B=KxN, C=MxN).", + ) + parser.add_argument( + 
"--wg-tile", + type=int, + nargs=2, + default=[256, 256], + help="Workgroup tile size M,N.", + ) + parser.add_argument( + "--sg-tile", + type=int, + nargs=2, + default=[32, 32], + help="Subgroup tile size M,N.", + ) + parser.add_argument( + "--k-tile", + type=int, + default=32, + help="Inner reduction dimension tile size K.", + ) + parser.add_argument( + "--load-tile-a", + type=int, + nargs=2, + default=[32, 16], + help="Tile size for loading A matrix for DPAS op.", + ) + parser.add_argument( + "--load-tile-b", + type=int, + nargs=2, + default=[32, 16], + help="Tile size for loading B matrix for DPAS op.", + ) + parser.add_argument( + "--prefetch-tile-a", + type=int, + nargs=2, + default=[8, 32], + help="Tile size for cooperative prefetching of subgroup A matrix", + ) + parser.add_argument( + "--prefetch-tile-b", + type=int, + nargs=2, + default=[8, 16], + help="Tile size for cooperative prefetching of subgroup B matrix", + ) + parser.add_argument( + "--nb-prefetch", + type=int, + default=1, + help="Number of initial prefetches.", + ) + parser.add_argument( + "--ab-type", + type=str, + choices=["f16", "f32"], + default="f16", + help="Data type of A and B matrices.", + ) + parser.add_argument( + "--c-type", + type=str, + choices=["f16", "f32"], + default="f32", + help="Data type of the C matrix.", + ) + parser.add_argument( + "--nruns", + type=int, + default=1000, + help="Number of runs to average the execution time.", + ) + parser.add_argument( + "--relu", + action="store_true", + help="Add relu op after the matrix multiplication (and bias if any).", + ) + parser.add_argument( + "--check-result", + action="store_true", + help="Check the result of the matrix multiplication.", + ) + parser.add_argument( + "--dump-kernel", + type=str, + choices=[ + "initial", + "tiled", + "vectorized", + "bufferized", + "xegpu-initial", + "xegpu-wg", + "xegpu-sg", + "xegpu-inst", + "final", + ], + help="Dump kernel IR at different stages of lowering.", + ) + parser.add_argument( + 
"--dump-schedule", + action="store_true", + help="Dump transform schedule.", + ) + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_cli() + + params = { + "auto_wg_d0": args.wg_tile[0], + "auto_wg_d1": args.wg_tile[1], + "auto_sg_d0": args.sg_tile[0], + "auto_sg_d1": args.sg_tile[1], + "auto_k": args.k_tile, + "auto_load_a_d0": args.load_tile_a[0], + "auto_load_a_d1": args.load_tile_a[1], + "auto_load_b_d0": args.load_tile_b[0], + "auto_load_b_d1": args.load_tile_b[1], + "auto_prefetch_a_d0": args.prefetch_tile_a[0], + "auto_prefetch_a_d1": args.prefetch_tile_a[1], + "auto_prefetch_b_d0": args.prefetch_tile_b[0], + "auto_prefetch_b_d1": args.prefetch_tile_b[1], + "auto_nb_prefetch": args.nb_prefetch, + } + + M, N, K = args.sizes + wload = XeGPUMatMul( + M=M, + N=N, + K=K, + ab_type=args.ab_type, + c_type=args.c_type, + has_bias=False, + has_relu=args.relu, + ) + + if args.dump_kernel or args.dump_schedule: + lower_payload( + wload, + dump_kernel=args.dump_kernel, + dump_schedule=args.dump_schedule, + schedule_parameters=params, + ) + else: + times = benchmark( + wload, + nruns=args.nruns, + nwarmup=15, + check_correctness=args.check_result, + schedule_parameters=params, + verbose=1, + ) + times *= 1e6 # convert to microseconds + elapsed = np.mean(times) + flop_count = wload.get_complexity()[0] + gflops = flop_count / (elapsed * 1e-6) / 1e9 + + def list2str(a): + return ",".join(map(str, a)) + + parts = [ + f"sizes={list2str(args.sizes)}", + f"dt={args.ab_type},{args.c_type}", + f"wg-tile={list2str(args.wg_tile)}", + f"sg-tile={list2str(args.sg_tile)}", + f"k-tile={args.k_tile}", + f"load-a-tile={list2str(args.load_tile_a)}", + f"load-b-tile={list2str(args.load_tile_b)}", + f"pf-a-tile={list2str(args.prefetch_tile_a)}", + f"pf-b-tile={list2str(args.prefetch_tile_b)}", + f"time(us): {elapsed:.2f}", + f"GFLOPS: {gflops:.2f}", + ] + print(" ".join(parts)) diff --git a/python/examples/xegpu_matmul/mlir_utils.py 
b/python/examples/xegpu_matmul/mlir_utils.py new file mode 100644 index 0000000..863c337 --- /dev/null +++ b/python/examples/xegpu_matmul/mlir_utils.py @@ -0,0 +1,33 @@ +from mlir import ir +from mlir.dialects import transform +from mlir.dialects.transform import structured +import os + + +def apply_registered_pass(*args, **kwargs): + return transform.apply_registered_pass(transform.AnyOpType.get(), *args, **kwargs) + + +def match(*args, **kwargs): + return structured.MatchOp(transform.AnyOpType.get(), *args, **kwargs) + + +def cse(op): + transform.ApplyCommonSubexpressionEliminationOp(op) + + +def canonicalize(op): + with ir.InsertionPoint(transform.ApplyPatternsOp(op).patterns): + transform.ApplyCanonicalizationPatternsOp() + + +def get_mlir_library_path(): + pkg_path = ir.__file__ + if "python_packages" in pkg_path: + # looks like a local mlir install + path = pkg_path.split("python_packages")[0] + os.sep + "lib" + else: + # maybe installed in python path + path = os.path.split(pkg_path)[0] + os.sep + "_mlir_libs" + assert os.path.isdir(path) + return path diff --git a/python/examples/xegpu_matmul/payload_generator.py b/python/examples/xegpu_matmul/payload_generator.py new file mode 100644 index 0000000..0f4d7b9 --- /dev/null +++ b/python/examples/xegpu_matmul/payload_generator.py @@ -0,0 +1,134 @@ +from mlir import ir +from mlir.dialects import func, linalg, gpu, bufferization, arith, tensor +from typing import Optional + + +def emit_gpu_alloc(mod, suffix, element_type, rank=2): + dyn = ir.ShapedType.get_dynamic_size() + memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type) + index_t = ir.IndexType.get() + i32_t = ir.IntegerType.get_signless(32) + with ir.InsertionPoint(mod.body): + f = func.FuncOp("gpu_alloc_" + suffix, (rank * (i32_t,), (memref_dyn_t,))) + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): + dims = [ + arith.IndexCastOp(index_t, f.arguments[0]), + arith.IndexCastOp(index_t, 
f.arguments[1]), + ] + alloc = gpu.alloc(memref_dyn_t, None, [], dims, []) + func.ReturnOp((alloc,)) + + +def emit_gpu_dealloc(mod, suffix, element_type, rank=2): + dyn = ir.ShapedType.get_dynamic_size() + memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type) + with ir.InsertionPoint(mod.body): + f = func.FuncOp("gpu_dealloc_" + suffix, ((memref_dyn_t,), ())) + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): + gpu.dealloc(None, [], f.arguments[0]) + func.ReturnOp(()) + + +def emit_gpu_copy(mod, suffix, element_type, rank=2): + """Emit GPU copy function.""" + dyn = ir.ShapedType.get_dynamic_size() + memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type) + with ir.InsertionPoint(mod.body): + f = func.FuncOp("gpu_copy_" + suffix, ((memref_dyn_t, memref_dyn_t), ())) + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): + src = f.arguments[0] + dst = f.arguments[1] + gpu.memcpy(None, [], dst, src) + func.ReturnOp(()) + + +def emit_gpu_util_funcs(mod, element_type): + """Emit GPU utility functions for allocation, deallocation and copy.""" + suffix = { + ir.F16Type.get(): "f16", + ir.F32Type.get(): "f32", + }[element_type] + emit_gpu_alloc(mod, suffix, element_type) + emit_gpu_dealloc(mod, suffix, element_type) + emit_gpu_copy(mod, suffix, element_type) + + +def generate_matmul_payload( + func_name: str, + M: int, + N: int, + K: int, + ab_type_str: str, + c_type_str: str, + has_bias: bool, + has_relu: bool, + context: Optional[ir.Context] = None, + location: Optional[ir.Location] = None, +) -> ir.Module: + """Generate payload function module.""" + if context is None: + context = ir.Context() + if location is None: + location = ir.Location.unknown(context) + + with context, location: + get_ir_dtype = { + "f16": ir.F16Type.get(), + "f32": ir.F32Type.get(), + } + ab_type = get_ir_dtype[ab_type_str] + c_type = get_ir_dtype[c_type_str] + tensor_a_t = 
ir.RankedTensorType.get((M, K), ab_type) + tensor_b_t = ir.RankedTensorType.get((K, N), ab_type) + tensor_c_t = ir.RankedTensorType.get((M, N), c_type) + tensor_bias_t = ir.RankedTensorType.get((N,), c_type) + memref_a_t = ir.MemRefType.get((M, K), ab_type) + memref_b_t = ir.MemRefType.get((K, N), ab_type) + memref_c_t = ir.MemRefType.get((M, N), c_type) + memref_bias_t = ir.MemRefType.get((N,), c_type) + mod = ir.Module.create() + with ir.InsertionPoint(mod.body): + fargs = [memref_a_t, memref_b_t] + if has_bias: + fargs.append(memref_bias_t) + fargs.append(memref_c_t) + f = func.FuncOp(func_name, (tuple(fargs), ())) + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): + A = f.arguments[0] + B = f.arguments[1] + C = f.arguments[-1] + a_tensor = bufferization.ToTensorOp(tensor_a_t, A, restrict=True) + b_tensor = bufferization.ToTensorOp(tensor_b_t, B, restrict=True) + c_tensor = bufferization.ToTensorOp( + tensor_c_t, C, restrict=True, writable=True + ) + mmul = linalg.matmul(a_tensor, b_tensor, outs=[c_tensor]) + terminal = mmul + if has_bias: + bias = f.arguments[2] + bias_tensor = bufferization.ToTensorOp( + tensor_bias_t, bias, restrict=True, writable=True + ) + empty = tensor.empty((M, N), c_type) + bcast = linalg.broadcast(bias_tensor, outs=[empty], dimensions=[0]) + terminal = linalg.add(bcast, terminal, outs=[empty]) + if has_relu: + zero = arith.constant(c_type, 0.0) + empty = tensor.empty((M, N), c_type) + zero_tensor = linalg.fill(zero, outs=[empty]) + terminal = linalg.max(terminal, zero_tensor, outs=[empty]) + + bufferization.MaterializeInDestinationOp( + None, terminal, C, restrict=True, writable=True + ) + func.ReturnOp(()) + + emit_gpu_util_funcs(mod, ab_type) + if c_type != ab_type: + emit_gpu_util_funcs(mod, c_type) + + return mod diff --git a/python/examples/xegpu_matmul/schedule.py b/python/examples/xegpu_matmul/schedule.py new file mode 100644 index 0000000..7aba91b --- /dev/null +++ 
b/python/examples/xegpu_matmul/schedule.py @@ -0,0 +1,486 @@ +from mlir import ir +from mlir.dialects.transform import loop +from mlir.dialects.transform import bufferization +from mlir.dialects.transform import xegpu +from mlir.dialects.bufferization import LayoutMapOption +from mlir.dialects import transform +from mlir.dialects.transform import structured +from mlir_utils import apply_registered_pass, match, cse, canonicalize + + +# hardware constraints +dpas_tile = [8, 16, 16] +prefetch_inst_data = [8, 16] +nb_workitems = 16 # workitems in subgroup + + +def get_schedule_module( + has_bias=False, + has_relu=False, + dump_kernel="", + params=None, + context=None, + location=None, +) -> ir.Module: + """Generate transform schedule module.""" + if context is None: + context = ir.Context() + if location is None: + location = ir.Location.unknown(context) + + with context, location: + mod = ir.Module.create() + mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() + with ir.InsertionPoint(mod.body): + named_sequence = transform.NamedSequenceOp( + "__transform_main", + [transform.AnyOpType.get()], # input types + [], # output types + arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], + ) + with ir.InsertionPoint(named_sequence.body): + xegpu_matmul_transform_schedule( + named_sequence, + has_bias=has_bias, + has_relu=has_relu, + dump_kernel=dump_kernel, + params=params, + ) + # placeholder for parameter division op + i32 = ir.IntegerType.get_signless(32) + paramInt32Type = transform.ParamType.get(i32) + div_named_sequence = transform.NamedSequenceOp( + "param_div", + [paramInt32Type, paramInt32Type], # input types + [paramInt32Type], # output types + arg_attrs=[ + {"transform.readonly": ir.UnitAttr.get()}, + {"transform.readonly": ir.UnitAttr.get()}, + ], + ) + with ir.InsertionPoint(div_named_sequence.body): + p = transform.ParamConstantOp( + paramInt32Type, ir.IntegerAttr.get(i32, 1) + ) + transform.YieldOp(p) + + return mod + + +def 
xegpu_matmul_transform_schedule( + named_sequence, + has_bias=False, + has_relu=False, + dump_kernel="", + params=None, +): + """Transform schedule for matmul-like payload.""" + mod = bundle_header(named_sequence) + mod, interrupted = bundle_xepu_matmul_schedule( + mod, + has_bias=has_bias, + has_relu=has_relu, + dump_kernel=dump_kernel, + params=params, + ) + if interrupted: + transform.YieldOp() + return + + mod, interrupted = bundle_xegpu_to_binary( + mod, + dump_kernel=dump_kernel, + ) + transform.YieldOp() + + +def bundle_header(named_sequence): + """Matches the payload module.""" + anytype = transform.AnyOpType.get() + func = match(named_sequence.bodyTarget, ops={"func.func"}) + mod = transform.get_parent_op( + anytype, + func, + op_name="builtin.module", + deduplicate=True, + ) + return mod + + +def geo_range(start, stop, factor): + """ + Returns a geometric range dict attribute. + + `stop` is inclusive. + """ + i32 = ir.IntegerType.get_signless(32) + return ir.DictAttr.get( + { + "start": ir.IntegerAttr.get(i32, start), + "stop": ir.IntegerAttr.get(i32, stop + 1), + "factor": ir.IntegerAttr.get(i32, factor), + } + ) + + +def lin_range(start, stop, step): + """ + Returns a linear range dict attribute. + + `stop` is inclusive. 
+ """ + i32 = ir.IntegerType.get_signless(32) + return ir.DictAttr.get( + { + "start": ir.IntegerAttr.get(i32, start), + "stop": ir.IntegerAttr.get(i32, stop + 1), + "step": ir.IntegerAttr.get(i32, step), + } + ) + + +def bundle_xepu_matmul_schedule( + mod, + has_bias=False, + has_relu=False, + dump_kernel="", + params=None, +): + """Schedule for lowering matmul-like payload to xegpu wg level.""" + if params is None: + raise ValueError("Schedule parameters must be provided.") + + # tunable parameters + wg_tile = [params["auto_wg_d0"], params["auto_wg_d1"]] + sg_tile = [params["auto_sg_d0"], params["auto_sg_d1"]] + k_tile = params["auto_k"] + + load_tile_a = [params["auto_load_a_d0"], params["auto_load_a_d1"]] + load_tile_b = [params["auto_load_b_d0"], params["auto_load_b_d1"]] + + prefetch_tile_a = [params["auto_prefetch_a_d0"], params["auto_prefetch_a_d1"]] + prefetch_tile_b = [params["auto_prefetch_b_d0"], params["auto_prefetch_b_d1"]] + nb_prefetch = params["auto_nb_prefetch"] + + # derived parameters + sg_layout = [wg_tile[0] // sg_tile[0], wg_tile[1] // sg_tile[1]] + # number of threads collapsed to 1d layout + nb_threads = sg_layout[0] * sg_layout[1] * nb_workitems + prefetch_layout_a = [ + wg_tile[0] // prefetch_tile_a[0], + k_tile // prefetch_tile_a[1], + ] + prefetch_layout_b = [ + k_tile // prefetch_tile_b[0], + wg_tile[1] // prefetch_tile_b[1], + ] + + # matmul matrix shapes + sg_tile_a = [sg_tile[0], k_tile] + sg_tile_b = [k_tile, sg_tile[1]] + + if dump_kernel == "initial": + return mod, True + + anytype = transform.AnyOpType.get() + anyvalue = transform.AnyValueType.get() + + # match the payload function + anchor = match(mod, ops={"linalg.matmul"}) + func = transform.get_parent_op( + anytype, + anchor, + op_name="func.func", + deduplicate=True, + ) + + dpas_shape_a = [dpas_tile[0], dpas_tile[2]] + dpas_shape_b = [dpas_tile[2], dpas_tile[1]] + dpas_shape_c = [dpas_tile[0], dpas_tile[1]] + + # wg tiling + if has_relu: + terminal = match(mod, 
ops={"linalg.max"}).result + elif has_bias: + terminal = match(mod, ops={"linalg.add"}).result + else: + terminal = match(mod, ops={"linalg.matmul"}).result + structured.FuseOp(terminal, tile_sizes=wg_tile, use_forall=True) + cse(mod) + canonicalize(mod) + + # k loop tiling + wg_matmul = match(mod, ops={"linalg.matmul"}).result + wgk_matmul, k_loop = structured.TileUsingForOp( + wg_matmul, sizes=[0, 0, k_tile] + ).results + + cse(func) + canonicalize(func) + + if dump_kernel == "tiled": + return mod, True + + # vectorize + func = structured.VectorizeChildrenAndApplyPatternsOp( + func, + fold_type_extensions_into_contract=True, + ).result + + # hoist loop invariant vector read/store ops + k_loop = match(func, ops={"scf.for"}) + loop.HoistLoopInvariantSubsetsOp(k_loop) + + cse(func) + canonicalize(func) + + if dump_kernel == "vectorized": + return mod, True + + # bufferize + + # eliminate empty tensors to avoid emitting extra copy ops + mod = apply_registered_pass(mod, "eliminate-empty-tensors") + identity_layout = LayoutMapOption.IdentityLayoutMap + mod = bufferization.OneShotBufferizeOp( + mod, + allow_return_allocs_from_loops=True, + bufferize_function_boundaries=True, + function_boundary_type_conversion=identity_layout, + ).result + # fold memref.subviews into vector.transfer_read/write ops + mod = apply_registered_pass(mod, "fold-memref-alias-ops") + cse(mod) + canonicalize(mod) + + if dump_kernel == "bufferized": + return mod, True + + # convert forall to parallel + wg_loop = match(mod, ops={"scf.forall"}) + wg_loop = loop.ForallToParallelOp([anytype], wg_loop) + func = transform.get_parent_op(anytype, wg_loop) + + # convert to scf.parallel to gpu.launch + func = apply_registered_pass(func, "gpu-map-parallel-loops") + func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") + func = apply_registered_pass(func, "lower-affine") + cse(func) + canonicalize(func) + + # set correct number of gpu threads + launch_op = match(func, ops={"gpu.launch"}) + 
xegpu.set_gpu_launch_threads(launch_op, threads=[nb_threads, 1, 1]) + + # outline gpu func + func = apply_registered_pass(func, "lower-affine") + canonicalize(func) + func = apply_registered_pass(func, "gpu-launch-sink-index-computations") + mod = apply_registered_pass(mod, "gpu-kernel-outlining") + cse(mod) + + # set xevm target + mod = apply_registered_pass( + mod, + "xevm-attach-target", + options={"O": "3", "chip": "bmg"}, + ) + + # convert vector to xegpu + gpu_mod = match(mod, ops={"gpu.module"}) + gpu_func = match(gpu_mod, ops={"gpu.func"}) + gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") + cse(gpu_func) + + if dump_kernel == "xegpu-initial": + return mod, True + + # add layouts to DPAS op operands + k_loop = match(gpu_func, ops={"scf.for"}) + dpas_op = match(k_loop, ops={"xegpu.dpas"}) + tile_a = transform.GetOperandOp(anyvalue, dpas_op, [0]) + tile_b = transform.GetOperandOp(anyvalue, dpas_op, [1]) + tile_c = transform.GetOperandOp(anyvalue, dpas_op, [2]) + + def convert_layout(value, input, target): + xegpu.convert_layout( + value, + input_sg_layout=input["sg_layout"], + input_sg_data=input["sg_data"], + input_inst_data=input["inst_data"], + target_sg_layout=target["sg_layout"], + target_sg_data=target["sg_data"], + target_inst_data=target["inst_data"], + ) + + # insert prefetch ops for DPAS A and B tiles + desc_prefetch_a = xegpu.insert_prefetch( + tile_a, + nb_prefetch=nb_prefetch, + ) + xegpu.set_desc_layout( + desc_prefetch_a, + sg_layout=prefetch_layout_a, + sg_data=prefetch_tile_a, + inst_data=prefetch_inst_data, + ) + desc_prefetch_b = xegpu.insert_prefetch( + tile_b, + nb_prefetch=nb_prefetch, + ) + xegpu.set_desc_layout( + desc_prefetch_b, + sg_layout=prefetch_layout_b, + sg_data=prefetch_tile_b, + inst_data=prefetch_inst_data, + ) + + # A tile load layout + layout_load_a = { + "sg_layout": sg_layout, + "sg_data": sg_tile_a, + "inst_data": load_tile_a, + } + desc_op_a = xegpu.get_desc_op(tile_a) + desc_op_a = 
xegpu.set_desc_layout( + target=desc_op_a, + **layout_load_a, + ) + # A tile dpas layout + layout_dpas_a = layout_load_a.copy() + layout_dpas_a["inst_data"] = dpas_shape_a + convert_layout(tile_a, layout_load_a, layout_dpas_a) + + # B tile load layout + layout_load_b = { + "sg_layout": sg_layout, + "sg_data": sg_tile_b, + "inst_data": load_tile_b, + } + desc_op_b = xegpu.get_desc_op(tile_b) + desc_op_b = xegpu.set_desc_layout( + target=desc_op_b, + **layout_load_b, + ) + # B tile dpas layout + layout_dpas_b = layout_load_b.copy() + layout_dpas_b["inst_data"] = dpas_shape_b + convert_layout(tile_b, layout_load_b, layout_dpas_b) + + # C tile layout + output_layout = { + "sg_layout": sg_layout, + "sg_data": sg_tile, + "inst_data": dpas_shape_c, + } + desc_op_c = xegpu.get_desc_op(tile_c) + desc_op_c = xegpu.set_desc_layout(desc_op_c, **output_layout) + # C tile dpas layout + xegpu.set_op_layout_attr(dpas_op, result=True, index=0, **output_layout) + + if has_relu: + # for post ops we need to add C layout manually + max_op = match(gpu_func, ops={"arith.maximumf"}).result + xegpu.set_op_layout_attr(max_op, result=True, index=0, **output_layout) + # find zero constant buffer and annotate it + const_buffer = transform.get_producer_of_operand(anytype, max_op, 1) + xegpu.set_op_layout_attr(const_buffer, result=True, index=0, **output_layout) + if has_bias: + # for post ops we need to add C layout manually + add_op = match(gpu_func, ops={"arith.addf"}).result + xegpu.set_op_layout_attr(add_op, result=True, index=0, **output_layout) + + # annotate broadcast op operands + bcast_op = transform.get_producer_of_operand(anytype, add_op, 0) + xegpu.set_op_layout_attr(bcast_op, result=True, index=0, **output_layout) + bcast_load = transform.get_producer_of_operand(anytype, bcast_op, 0) + xegpu.set_op_layout_attr( + bcast_load, result=True, index=0, **output_layout, slice_dims=[0] + ) + output_layout_dim1 = { + "sg_layout": [sg_layout[1]], + "sg_data": [sg_tile[1]], + "inst_data": 
[dpas_shape_c[1]], + } + offset = transform.get_producer_of_operand(anytype, bcast_load, 1) + xegpu.set_op_layout_attr(offset, result=True, index=0, **output_layout_dim1) + aux1 = transform.get_producer_of_operand(anytype, offset, 0) + xegpu.set_op_layout_attr(aux1, result=True, index=0, **output_layout_dim1) + aux2 = transform.get_producer_of_operand(anytype, offset, 1) + xegpu.set_op_layout_attr(aux2, result=True, index=0, **output_layout_dim1) + mask = transform.get_producer_of_operand(anytype, bcast_load, 2) + xegpu.set_op_layout_attr(mask, result=True, index=0, **output_layout_dim1) + raise NotImplementedError("Bias layout propagation is not supported.") + cse(gpu_func) + canonicalize(gpu_func) + + # hoist desc ops out of reduction loop + transform.apply_licm(k_loop) + + canonicalize(gpu_func) + cse(gpu_func) + + if dump_kernel == "xegpu-wg": + return mod, True + + return mod, False + + +def bundle_xegpu_to_binary(mod, dump_kernel=""): + """Schedule for lowering xegpu wg level to binary.""" + # This schedule corresponds to upstream MLIR XeVM lowering pipeline + # and is payload independent. 
+ + gpu_mod = match(mod, ops={"gpu.module"}) + # xegpu distribution + gpu_func = match(gpu_mod, ops={"gpu.func"}) + gpu_func = apply_registered_pass(gpu_func, "xegpu-wg-to-sg-distribute") + cse(gpu_func) + + if dump_kernel == "xegpu-sg": + return mod, True + + gpu_func = apply_registered_pass(gpu_func, "lower-affine") + cse(gpu_func) + gpu_func = apply_registered_pass(gpu_func, "xegpu-blocking") + canonicalize(gpu_func) + cse(gpu_func) + + if dump_kernel == "xegpu-inst": + return mod, True + + gpu_func = apply_registered_pass(gpu_func, "xegpu-propagate-layout") + gpu_mod = apply_registered_pass(gpu_mod, "xegpu-subgroup-distribute") + canonicalize(gpu_mod) + cse(gpu_mod) + gpu_mod = apply_registered_pass(gpu_mod, "loop-invariant-code-motion") + cse(gpu_mod) + gpu_mod = apply_registered_pass(gpu_mod, "xegpu-vector-linearize") + gpu_mod = apply_registered_pass(gpu_mod, "convert-xegpu-to-xevm") + gpu_mod = apply_registered_pass( + gpu_mod, "convert-gpu-to-llvm-spv", options={"use-64bit-index": "true"} + ) + gpu_mod = apply_registered_pass(gpu_mod, "convert-xevm-to-llvm") + cse(gpu_mod) + + func = match(mod, ops={"func.func"}) + func = apply_registered_pass(func, "gpu-async-region") + + mod = apply_registered_pass(mod, "reconcile-unrealized-casts") + mod = apply_registered_pass(mod, "convert-vector-to-scf") + mod = apply_registered_pass(mod, "convert-scf-to-cf") + mod = apply_registered_pass(mod, "expand-strided-metadata") + mod = apply_registered_pass(mod, "finalize-memref-to-llvm") + mod = apply_registered_pass(mod, "convert-cf-to-llvm") + mod = apply_registered_pass(mod, "convert-vector-to-llvm") + mod = apply_registered_pass(mod, "convert-arith-to-llvm") + mod = apply_registered_pass(mod, "convert-index-to-llvm") + mod = apply_registered_pass(mod, "convert-func-to-llvm") + mod = apply_registered_pass(mod, "convert-math-to-llvm") + mod = apply_registered_pass(mod, "gpu-to-llvm") + mod = apply_registered_pass(mod, "lower-affine") + mod = apply_registered_pass(mod, 
"reconcile-unrealized-casts") + cse(mod) + mod = apply_registered_pass(mod, "gpu-module-to-binary") + + return mod, False From 7f334a075a0488bf6f338a9fb16986d9743fde83 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Tue, 25 Nov 2025 23:26:26 +0200 Subject: [PATCH 02/16] define context and location only once --- .../examples/xegpu_matmul/execution_engine.py | 89 ++++++-------- python/examples/xegpu_matmul/matmul.py | 100 ++++++++-------- .../xegpu_matmul/payload_generator.py | 111 ++++++++---------- python/examples/xegpu_matmul/schedule.py | 72 +++++------- 4 files changed, 168 insertions(+), 204 deletions(-) diff --git a/python/examples/xegpu_matmul/execution_engine.py b/python/examples/xegpu_matmul/execution_engine.py index 562f659..7e6ca83 100644 --- a/python/examples/xegpu_matmul/execution_engine.py +++ b/python/examples/xegpu_matmul/execution_engine.py @@ -14,8 +14,6 @@ def get_engine(payload_module, opt_level=3) -> ExecutionEngine: - context = ir.Context() - location = ir.Location.unknown(context) lib_dir = get_mlir_library_path() libs = [ "libmlir_levelzero_runtime.so", @@ -23,30 +21,26 @@ def get_engine(payload_module, opt_level=3) -> ExecutionEngine: "libmlir_c_runner_utils.so", ] libs = [os.path.join(lib_dir, lib) for lib in libs] - with context, location: - execution_engine = ExecutionEngine( - payload_module, opt_level=opt_level, shared_libs=libs - ) - execution_engine.initialize() + execution_engine = ExecutionEngine( + payload_module, opt_level=opt_level, shared_libs=libs + ) + execution_engine.initialize() return execution_engine def apply_transform_schedule( payload_module, schedule_module, - context, - location, dump_kernel: Optional[str] = None, dump_schedule: bool = False, ): if not dump_kernel or dump_kernel != "initial": - with context, location: - # invoke transform interpreter directly - transform_interpreter.apply_named_sequence( - payload_root=payload_module, - transform_root=schedule_module.body.operations[0], - 
transform_module=schedule_module, - ) + # invoke transform interpreter directly + transform_interpreter.apply_named_sequence( + payload_root=payload_module, + transform_root=schedule_module.body.operations[0], + transform_module=schedule_module, + ) if dump_kernel: print(payload_module) if dump_schedule: @@ -66,8 +60,6 @@ def lower_payload( apply_transform_schedule( payload_module, schedule_module, - workload.context, - workload.location, dump_kernel=dump_kernel, dump_schedule=dump_schedule, ) @@ -125,43 +117,40 @@ def benchmark( payload_arguments = payload_func.type.inputs # emit benchmark function that calls payload and times it - with workload.context, workload.location: - with ir.InsertionPoint(payload_module.body): - # define rtclock function - f64_t = ir.F64Type.get() - f = func.FuncOp("rtclock", ((), (f64_t,)), visibility="private") - # emit benchmark function - time_memref_t = ir.MemRefType.get((nruns,), f64_t) - args = payload_arguments + [time_memref_t] - f = func.FuncOp("benchmark", (tuple(args), ())) - f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - with ir.InsertionPoint(f.add_entry_block()): - index_t = ir.IndexType.get() - zero = arith.ConstantOp(index_t, 0) - one = arith.ConstantOp(index_t, 1) - nwarmup_cst = arith.ConstantOp(index_t, nwarmup) - for_op = scf.ForOp(zero, nwarmup_cst, one) - with ir.InsertionPoint(for_op.body): - func.CallOp(payload_func, list(f.arguments[: len(payload_arguments)])) - scf.YieldOp(()) - nruns_cst = arith.ConstantOp(index_t, nruns) - for_op = scf.ForOp(zero, nruns_cst, one) - i = for_op.induction_variable - with ir.InsertionPoint(for_op.body): - tic = func.CallOp((f64_t,), "rtclock", ()).result - func.CallOp(payload_func, list(f.arguments[: len(payload_arguments)])) - toc = func.CallOp((f64_t,), "rtclock", ()).result - time = arith.SubFOp(toc, tic) - memref.StoreOp(time, f.arguments[-1], [i]) - scf.YieldOp(()) - func.ReturnOp(()) + with ir.InsertionPoint(payload_module.body): + # define rtclock function + 
f64_t = ir.F64Type.get() + f = func.FuncOp("rtclock", ((), (f64_t,)), visibility="private") + # emit benchmark function + time_memref_t = ir.MemRefType.get((nruns,), f64_t) + args = payload_arguments + [time_memref_t] + f = func.FuncOp("benchmark", (tuple(args), ())) + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): + index_t = ir.IndexType.get() + zero = arith.ConstantOp(index_t, 0) + one = arith.ConstantOp(index_t, 1) + nwarmup_cst = arith.ConstantOp(index_t, nwarmup) + for_op = scf.ForOp(zero, nwarmup_cst, one) + with ir.InsertionPoint(for_op.body): + func.CallOp(payload_func, list(f.arguments[: len(payload_arguments)])) + scf.YieldOp(()) + nruns_cst = arith.ConstantOp(index_t, nruns) + for_op = scf.ForOp(zero, nruns_cst, one) + i = for_op.induction_variable + with ir.InsertionPoint(for_op.body): + tic = func.CallOp((f64_t,), "rtclock", ()).result + func.CallOp(payload_func, list(f.arguments[: len(payload_arguments)])) + toc = func.CallOp((f64_t,), "rtclock", ()).result + time = arith.SubFOp(toc, tic) + memref.StoreOp(time, f.arguments[-1], [i]) + scf.YieldOp(()) + func.ReturnOp(()) # lower apply_transform_schedule( payload_module, workload.schedule_module(parameters=schedule_parameters), - workload.context, - workload.location, ) # get execution engine, rtclock requires mlir_c_runner engine = get_engine(payload_module) diff --git a/python/examples/xegpu_matmul/matmul.py b/python/examples/xegpu_matmul/matmul.py index 2a2985b..1e10444 100644 --- a/python/examples/xegpu_matmul/matmul.py +++ b/python/examples/xegpu_matmul/matmul.py @@ -58,8 +58,6 @@ def __init__( self.has_relu = has_relu if has_bias: raise NotImplementedError("Bias is not implemented yet") - self.context = ir.Context() - self.location = ir.Location.unknown(context=self.context) # cache allocated memrefs self.gpu_memrefs = {} @@ -192,15 +190,11 @@ def payload_module(self) -> ir.Module: c_type_str=self.c_type, has_bias=self.has_bias, 
has_relu=self.has_relu, - context=self.context, - location=self.location, ) return mod def schedule_module(self, dump_kernel=None, parameters=None) -> ir.Module: return get_schedule_module( - context=self.context, - location=self.location, has_bias=self.has_bias, has_relu=self.has_relu, dump_kernel=dump_kernel, @@ -351,51 +345,53 @@ def parse_cli(): } M, N, K = args.sizes - wload = XeGPUMatMul( - M=M, - N=N, - K=K, - ab_type=args.ab_type, - c_type=args.c_type, - has_bias=False, - has_relu=args.relu, - ) - if args.dump_kernel or args.dump_schedule: - lower_payload( - wload, - dump_kernel=args.dump_kernel, - dump_schedule=args.dump_schedule, - schedule_parameters=params, - ) - else: - times = benchmark( - wload, - nruns=args.nruns, - nwarmup=15, - check_correctness=args.check_result, - schedule_parameters=params, - verbose=1, + with ir.Context(), ir.Location.unknown(): + wload = XeGPUMatMul( + M=M, + N=N, + K=K, + ab_type=args.ab_type, + c_type=args.c_type, + has_bias=False, + has_relu=args.relu, ) - times *= 1e6 # convert to microseconds - elapsed = np.mean(times) - flop_count = wload.get_complexity()[0] - gflops = flop_count / (elapsed * 1e-6) / 1e9 - - def list2str(a): - return ",".join(map(str, a)) - - parts = [ - f"sizes={list2str(args.sizes)}", - f"dt={args.ab_type},{args.c_type}", - f"wg-tile={list2str(args.wg_tile)}", - f"sg-tile={list2str(args.sg_tile)}", - f"k-tile={args.k_tile}", - f"load-a-tile={list2str(args.load_tile_a)}", - f"load-b-tile={list2str(args.load_tile_b)}", - f"pf-a-tile={list2str(args.prefetch_tile_a)}", - f"pf-b-tile={list2str(args.prefetch_tile_b)}", - f"time(us): {elapsed:.2f}", - f"GFLOPS: {gflops:.2f}", - ] - print(" ".join(parts)) + + if args.dump_kernel or args.dump_schedule: + lower_payload( + wload, + dump_kernel=args.dump_kernel, + dump_schedule=args.dump_schedule, + schedule_parameters=params, + ) + else: + times = benchmark( + wload, + nruns=args.nruns, + nwarmup=15, + check_correctness=args.check_result, + 
schedule_parameters=params, + verbose=1, + ) + times *= 1e6 # convert to microseconds + elapsed = np.mean(times) + flop_count = wload.get_complexity()[0] + gflops = flop_count / (elapsed * 1e-6) / 1e9 + + def list2str(a): + return ",".join(map(str, a)) + + parts = [ + f"sizes={list2str(args.sizes)}", + f"dt={args.ab_type},{args.c_type}", + f"wg-tile={list2str(args.wg_tile)}", + f"sg-tile={list2str(args.sg_tile)}", + f"k-tile={args.k_tile}", + f"load-a-tile={list2str(args.load_tile_a)}", + f"load-b-tile={list2str(args.load_tile_b)}", + f"pf-a-tile={list2str(args.prefetch_tile_a)}", + f"pf-b-tile={list2str(args.prefetch_tile_b)}", + f"time(us): {elapsed:.2f}", + f"GFLOPS: {gflops:.2f}", + ] + print(" ".join(parts)) diff --git a/python/examples/xegpu_matmul/payload_generator.py b/python/examples/xegpu_matmul/payload_generator.py index 0f4d7b9..58ae16d 100644 --- a/python/examples/xegpu_matmul/payload_generator.py +++ b/python/examples/xegpu_matmul/payload_generator.py @@ -1,6 +1,5 @@ from mlir import ir from mlir.dialects import func, linalg, gpu, bufferization, arith, tensor -from typing import Optional def emit_gpu_alloc(mod, suffix, element_type, rank=2): @@ -65,70 +64,60 @@ def generate_matmul_payload( c_type_str: str, has_bias: bool, has_relu: bool, - context: Optional[ir.Context] = None, - location: Optional[ir.Location] = None, ) -> ir.Module: """Generate payload function module.""" - if context is None: - context = ir.Context() - if location is None: - location = ir.Location.unknown(context) - - with context, location: - get_ir_dtype = { - "f16": ir.F16Type.get(), - "f32": ir.F32Type.get(), - } - ab_type = get_ir_dtype[ab_type_str] - c_type = get_ir_dtype[c_type_str] - tensor_a_t = ir.RankedTensorType.get((M, K), ab_type) - tensor_b_t = ir.RankedTensorType.get((K, N), ab_type) - tensor_c_t = ir.RankedTensorType.get((M, N), c_type) - tensor_bias_t = ir.RankedTensorType.get((N,), c_type) - memref_a_t = ir.MemRefType.get((M, K), ab_type) - memref_b_t = 
ir.MemRefType.get((K, N), ab_type) - memref_c_t = ir.MemRefType.get((M, N), c_type) - memref_bias_t = ir.MemRefType.get((N,), c_type) - mod = ir.Module.create() - with ir.InsertionPoint(mod.body): - fargs = [memref_a_t, memref_b_t] - if has_bias: - fargs.append(memref_bias_t) - fargs.append(memref_c_t) - f = func.FuncOp(func_name, (tuple(fargs), ())) - f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - with ir.InsertionPoint(f.add_entry_block()): - A = f.arguments[0] - B = f.arguments[1] - C = f.arguments[-1] - a_tensor = bufferization.ToTensorOp(tensor_a_t, A, restrict=True) - b_tensor = bufferization.ToTensorOp(tensor_b_t, B, restrict=True) - c_tensor = bufferization.ToTensorOp( - tensor_c_t, C, restrict=True, writable=True + get_ir_dtype = { + "f16": ir.F16Type.get(), + "f32": ir.F32Type.get(), + } + ab_type = get_ir_dtype[ab_type_str] + c_type = get_ir_dtype[c_type_str] + tensor_a_t = ir.RankedTensorType.get((M, K), ab_type) + tensor_b_t = ir.RankedTensorType.get((K, N), ab_type) + tensor_c_t = ir.RankedTensorType.get((M, N), c_type) + tensor_bias_t = ir.RankedTensorType.get((N,), c_type) + memref_a_t = ir.MemRefType.get((M, K), ab_type) + memref_b_t = ir.MemRefType.get((K, N), ab_type) + memref_c_t = ir.MemRefType.get((M, N), c_type) + memref_bias_t = ir.MemRefType.get((N,), c_type) + mod = ir.Module.create() + with ir.InsertionPoint(mod.body): + fargs = [memref_a_t, memref_b_t] + if has_bias: + fargs.append(memref_bias_t) + fargs.append(memref_c_t) + f = func.FuncOp(func_name, (tuple(fargs), ())) + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): + A = f.arguments[0] + B = f.arguments[1] + C = f.arguments[-1] + a_tensor = bufferization.ToTensorOp(tensor_a_t, A, restrict=True) + b_tensor = bufferization.ToTensorOp(tensor_b_t, B, restrict=True) + c_tensor = bufferization.ToTensorOp(tensor_c_t, C, restrict=True, writable=True) + mmul = linalg.matmul(a_tensor, b_tensor, outs=[c_tensor]) + terminal = 
mmul + if has_bias: + bias = f.arguments[2] + bias_tensor = bufferization.ToTensorOp( + tensor_bias_t, bias, restrict=True, writable=True ) - mmul = linalg.matmul(a_tensor, b_tensor, outs=[c_tensor]) - terminal = mmul - if has_bias: - bias = f.arguments[2] - bias_tensor = bufferization.ToTensorOp( - tensor_bias_t, bias, restrict=True, writable=True - ) - empty = tensor.empty((M, N), c_type) - bcast = linalg.broadcast(bias_tensor, outs=[empty], dimensions=[0]) - terminal = linalg.add(bcast, terminal, outs=[empty]) - if has_relu: - zero = arith.constant(c_type, 0.0) - empty = tensor.empty((M, N), c_type) - zero_tensor = linalg.fill(zero, outs=[empty]) - terminal = linalg.max(terminal, zero_tensor, outs=[empty]) + empty = tensor.empty((M, N), c_type) + bcast = linalg.broadcast(bias_tensor, outs=[empty], dimensions=[0]) + terminal = linalg.add(bcast, terminal, outs=[empty]) + if has_relu: + zero = arith.constant(c_type, 0.0) + empty = tensor.empty((M, N), c_type) + zero_tensor = linalg.fill(zero, outs=[empty]) + terminal = linalg.max(terminal, zero_tensor, outs=[empty]) - bufferization.MaterializeInDestinationOp( - None, terminal, C, restrict=True, writable=True - ) - func.ReturnOp(()) + bufferization.MaterializeInDestinationOp( + None, terminal, C, restrict=True, writable=True + ) + func.ReturnOp(()) - emit_gpu_util_funcs(mod, ab_type) - if c_type != ab_type: - emit_gpu_util_funcs(mod, c_type) + emit_gpu_util_funcs(mod, ab_type) + if c_type != ab_type: + emit_gpu_util_funcs(mod, c_type) return mod diff --git a/python/examples/xegpu_matmul/schedule.py b/python/examples/xegpu_matmul/schedule.py index 7aba91b..536c1d9 100644 --- a/python/examples/xegpu_matmul/schedule.py +++ b/python/examples/xegpu_matmul/schedule.py @@ -19,50 +19,40 @@ def get_schedule_module( has_relu=False, dump_kernel="", params=None, - context=None, - location=None, ) -> ir.Module: """Generate transform schedule module.""" - if context is None: - context = ir.Context() - if location is None: - 
location = ir.Location.unknown(context) - - with context, location: - mod = ir.Module.create() - mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() - with ir.InsertionPoint(mod.body): - named_sequence = transform.NamedSequenceOp( - "__transform_main", - [transform.AnyOpType.get()], # input types - [], # output types - arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], - ) - with ir.InsertionPoint(named_sequence.body): - xegpu_matmul_transform_schedule( - named_sequence, - has_bias=has_bias, - has_relu=has_relu, - dump_kernel=dump_kernel, - params=params, - ) - # placeholder for parameter division op - i32 = ir.IntegerType.get_signless(32) - paramInt32Type = transform.ParamType.get(i32) - div_named_sequence = transform.NamedSequenceOp( - "param_div", - [paramInt32Type, paramInt32Type], # input types - [paramInt32Type], # output types - arg_attrs=[ - {"transform.readonly": ir.UnitAttr.get()}, - {"transform.readonly": ir.UnitAttr.get()}, - ], + mod = ir.Module.create() + mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() + with ir.InsertionPoint(mod.body): + named_sequence = transform.NamedSequenceOp( + "__transform_main", + [transform.AnyOpType.get()], # input types + [], # output types + arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], + ) + with ir.InsertionPoint(named_sequence.body): + xegpu_matmul_transform_schedule( + named_sequence, + has_bias=has_bias, + has_relu=has_relu, + dump_kernel=dump_kernel, + params=params, ) - with ir.InsertionPoint(div_named_sequence.body): - p = transform.ParamConstantOp( - paramInt32Type, ir.IntegerAttr.get(i32, 1) - ) - transform.YieldOp(p) + # placeholder for parameter division op + i32 = ir.IntegerType.get_signless(32) + paramInt32Type = transform.ParamType.get(i32) + div_named_sequence = transform.NamedSequenceOp( + "param_div", + [paramInt32Type, paramInt32Type], # input types + [paramInt32Type], # output types + arg_attrs=[ + {"transform.readonly": 
ir.UnitAttr.get()}, + {"transform.readonly": ir.UnitAttr.get()}, + ], + ) + with ir.InsertionPoint(div_named_sequence.body): + p = transform.ParamConstantOp(paramInt32Type, ir.IntegerAttr.get(i32, 1)) + transform.YieldOp(p) return mod From 8e8c9e0f25e5ae5ac398e5f5db2d5f1a8a7200c1 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Tue, 25 Nov 2025 23:44:07 +0200 Subject: [PATCH 03/16] function arg typing --- .../examples/xegpu_matmul/execution_engine.py | 22 +++---- python/examples/xegpu_matmul/matmul.py | 37 ++++++++--- .../xegpu_matmul/payload_generator.py | 8 +-- python/examples/xegpu_matmul/schedule.py | 63 +++++-------------- 4 files changed, 57 insertions(+), 73 deletions(-) diff --git a/python/examples/xegpu_matmul/execution_engine.py b/python/examples/xegpu_matmul/execution_engine.py index 7e6ca83..63cfbc4 100644 --- a/python/examples/xegpu_matmul/execution_engine.py +++ b/python/examples/xegpu_matmul/execution_engine.py @@ -3,7 +3,6 @@ import os from typing import Optional -from mlir.dialects.transform import interpreter as transform_interpreter from mlir.dialects import func, arith, scf, memref from mlir.execution_engine import ExecutionEngine from mlir import ir @@ -13,7 +12,7 @@ from mlir_utils import get_mlir_library_path -def get_engine(payload_module, opt_level=3) -> ExecutionEngine: +def get_engine(payload_module: ir.Module, opt_level: int = 3) -> ExecutionEngine: lib_dir = get_mlir_library_path() libs = [ "libmlir_levelzero_runtime.so", @@ -29,18 +28,15 @@ def get_engine(payload_module, opt_level=3) -> ExecutionEngine: def apply_transform_schedule( - payload_module, - schedule_module, + payload_module: ir.Module, + schedule_module: ir.Module, dump_kernel: Optional[str] = None, dump_schedule: bool = False, ): if not dump_kernel or dump_kernel != "initial": - # invoke transform interpreter directly - transform_interpreter.apply_named_sequence( - payload_root=payload_module, - transform_root=schedule_module.body.operations[0], - 
transform_module=schedule_module, - ) + # apply schedule on payload module + named_seq = schedule_module.body.operations[0] + named_seq.apply(payload_module) if dump_kernel: print(payload_module) if dump_schedule: @@ -48,7 +44,7 @@ def apply_transform_schedule( def lower_payload( - workload, + workload: object, dump_kernel: Optional[str] = None, dump_schedule: bool = False, schedule_parameters: Optional[dict] = None, @@ -67,7 +63,7 @@ def lower_payload( def execute( - workload, + workload: object, check_correctness: bool = True, schedule_parameters: Optional[dict] = None, verbose: int = 0, @@ -94,7 +90,7 @@ def execute( def benchmark( - workload, + workload: object, nruns: int = 100, nwarmup: int = 10, schedule_parameters: Optional[dict] = None, diff --git a/python/examples/xegpu_matmul/matmul.py b/python/examples/xegpu_matmul/matmul.py index 1e10444..097dbd5 100644 --- a/python/examples/xegpu_matmul/matmul.py +++ b/python/examples/xegpu_matmul/matmul.py @@ -9,6 +9,8 @@ make_nd_memref_descriptor, as_ctype, ) +from mlir.execution_engine import ExecutionEngine +from typing import Optional import ctypes from contextlib import contextmanager from functools import cached_property @@ -20,7 +22,7 @@ import argparse -def numpy_to_ctype(arr) -> ctypes._Pointer: +def numpy_to_ctype(arr: np.ndarray) -> ctypes._Pointer: """Convert numpy array to memref and ctypes **void pointer.""" return memref_to_ctype(get_ranked_memref_descriptor(arr)) @@ -38,7 +40,14 @@ class XeGPUMatMul: payload_function_name = "payload" def __init__( - self, M, N, K, ab_type="f32", c_type="f32", has_bias=False, has_relu=False + self, + M: int, + N: int, + K: int, + ab_type: str = "f32", + c_type: str = "f32", + has_bias: bool = False, + has_relu: bool = False, ): self.M = M self.N = N @@ -62,7 +71,11 @@ def __init__( self.gpu_memrefs = {} def _allocate_array( - self, name, shape, dtype_str, execution_engine + self, + name: str, + shape: tuple[int, ...], + dtype_str: str, + execution_engine: 
ExecutionEngine, ) -> ctypes.Structure: key = (name, dtype_str) if key in self.gpu_memrefs: @@ -79,12 +92,12 @@ def _allocate_array( self.gpu_memrefs[key] = mref return mref - def _allocate_inputs(self, execution_engine): + def _allocate_inputs(self, execution_engine: ExecutionEngine): self._allocate_array("A", self.a_shape, self.ab_type, execution_engine) self._allocate_array("B", self.b_shape, self.ab_type, execution_engine) self._allocate_array("C", self.c_shape, self.c_type, execution_engine) - def _deallocate_all(self, execution_engine): + def _deallocate_all(self, execution_engine: ExecutionEngine): for (_, dtype_str), mref in self.gpu_memrefs.items(): dealloc_func = execution_engine.lookup("gpu_dealloc_" + dtype_str) ptr_mref = ctypes.pointer(ctypes.pointer(mref)) @@ -92,7 +105,7 @@ def _deallocate_all(self, execution_engine): self.gpu_memrefs = {} @contextmanager - def allocate(self, execution_engine): + def allocate(self, execution_engine: ExecutionEngine): try: self._allocate_inputs(execution_engine) yield None @@ -128,7 +141,9 @@ def _reference_solution(self) -> np.ndarray: raise NotImplementedError("Bias verification not implemented") return C_ref - def get_input_arrays(self, execution_engine) -> list[ctypes.Structure]: + def get_input_arrays( + self, execution_engine: ExecutionEngine + ) -> list[ctypes.Structure]: A_gpu = self._allocate_array("A", self.a_shape, self.ab_type, execution_engine) B_gpu = self._allocate_array("B", self.b_shape, self.ab_type, execution_engine) C_gpu = self._allocate_array("C", self.c_shape, self.c_type, execution_engine) @@ -144,7 +159,9 @@ def get_input_arrays(self, execution_engine) -> list[ctypes.Structure]: # return memrefs for the payload function return [A_gpu, B_gpu, C_gpu] - def check_correctness(self, execution_engine, verbose: int = 0) -> bool: + def check_correctness( + self, execution_engine: ExecutionEngine, verbose: int = 0 + ) -> bool: # copy result from device to host C_gpu = self.gpu_memrefs[("C", 
self.c_type)] C_host_copy = np.zeros((self.M, self.N), dtype=self.c_dtype) @@ -193,7 +210,9 @@ def payload_module(self) -> ir.Module: ) return mod - def schedule_module(self, dump_kernel=None, parameters=None) -> ir.Module: + def schedule_module( + self, dump_kernel: str = None, parameters: Optional[dict] = None + ) -> ir.Module: return get_schedule_module( has_bias=self.has_bias, has_relu=self.has_relu, diff --git a/python/examples/xegpu_matmul/payload_generator.py b/python/examples/xegpu_matmul/payload_generator.py index 58ae16d..7ceb78c 100644 --- a/python/examples/xegpu_matmul/payload_generator.py +++ b/python/examples/xegpu_matmul/payload_generator.py @@ -2,7 +2,7 @@ from mlir.dialects import func, linalg, gpu, bufferization, arith, tensor -def emit_gpu_alloc(mod, suffix, element_type, rank=2): +def emit_gpu_alloc(mod: ir.Module, suffix: str, element_type: ir.Type, rank: int = 2): dyn = ir.ShapedType.get_dynamic_size() memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type) index_t = ir.IndexType.get() @@ -19,7 +19,7 @@ def emit_gpu_alloc(mod, suffix, element_type, rank=2): func.ReturnOp((alloc,)) -def emit_gpu_dealloc(mod, suffix, element_type, rank=2): +def emit_gpu_dealloc(mod: ir.Module, suffix: str, element_type: ir.Type, rank: int = 2): dyn = ir.ShapedType.get_dynamic_size() memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type) with ir.InsertionPoint(mod.body): @@ -30,7 +30,7 @@ def emit_gpu_dealloc(mod, suffix, element_type, rank=2): func.ReturnOp(()) -def emit_gpu_copy(mod, suffix, element_type, rank=2): +def emit_gpu_copy(mod: ir.Module, suffix: str, element_type: ir.Type, rank: int = 2): """Emit GPU copy function.""" dyn = ir.ShapedType.get_dynamic_size() memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type) @@ -44,7 +44,7 @@ def emit_gpu_copy(mod, suffix, element_type, rank=2): func.ReturnOp(()) -def emit_gpu_util_funcs(mod, element_type): +def emit_gpu_util_funcs(mod: ir.Module, element_type: ir.Type): """Emit GPU utility 
functions for allocation, deallocation and copy.""" suffix = { ir.F16Type.get(): "f16", diff --git a/python/examples/xegpu_matmul/schedule.py b/python/examples/xegpu_matmul/schedule.py index 536c1d9..4b0cddf 100644 --- a/python/examples/xegpu_matmul/schedule.py +++ b/python/examples/xegpu_matmul/schedule.py @@ -6,6 +6,7 @@ from mlir.dialects import transform from mlir.dialects.transform import structured from mlir_utils import apply_registered_pass, match, cse, canonicalize +from typing import Optional # hardware constraints @@ -15,10 +16,10 @@ def get_schedule_module( - has_bias=False, - has_relu=False, - dump_kernel="", - params=None, + has_bias: bool = False, + has_relu: bool = False, + dump_kernel: str = "", + params: Optional[dict] = None, ) -> ir.Module: """Generate transform schedule module.""" mod = ir.Module.create() @@ -58,11 +59,11 @@ def get_schedule_module( def xegpu_matmul_transform_schedule( - named_sequence, - has_bias=False, - has_relu=False, - dump_kernel="", - params=None, + named_sequence: transform.NamedSequenceOp, + has_bias: bool = False, + has_relu: bool = False, + dump_kernel: str = "", + params: Optional[dict] = None, ): """Transform schedule for matmul-like payload.""" mod = bundle_header(named_sequence) @@ -84,7 +85,7 @@ def xegpu_matmul_transform_schedule( transform.YieldOp() -def bundle_header(named_sequence): +def bundle_header(named_sequence: transform.NamedSequenceOp): """Matches the payload module.""" anytype = transform.AnyOpType.get() func = match(named_sequence.bodyTarget, ops={"func.func"}) @@ -97,44 +98,12 @@ def bundle_header(named_sequence): return mod -def geo_range(start, stop, factor): - """ - Returns a geometric range dict attribute. - - `stop` is inclusive. 
- """ - i32 = ir.IntegerType.get_signless(32) - return ir.DictAttr.get( - { - "start": ir.IntegerAttr.get(i32, start), - "stop": ir.IntegerAttr.get(i32, stop + 1), - "factor": ir.IntegerAttr.get(i32, factor), - } - ) - - -def lin_range(start, stop, step): - """ - Returns a linear range dict attribute. - - `stop` is inclusive. - """ - i32 = ir.IntegerType.get_signless(32) - return ir.DictAttr.get( - { - "start": ir.IntegerAttr.get(i32, start), - "stop": ir.IntegerAttr.get(i32, stop + 1), - "step": ir.IntegerAttr.get(i32, step), - } - ) - - def bundle_xepu_matmul_schedule( mod, - has_bias=False, - has_relu=False, - dump_kernel="", - params=None, + has_bias: bool = False, + has_relu: bool = False, + dump_kernel: str = "", + params: Optional[dict] = None, ): """Schedule for lowering matmul-like payload to xegpu wg level.""" if params is None: @@ -416,7 +385,7 @@ def convert_layout(value, input, target): return mod, False -def bundle_xegpu_to_binary(mod, dump_kernel=""): +def bundle_xegpu_to_binary(mod, dump_kernel: str = ""): """Schedule for lowering xegpu wg level to binary.""" # This schedule corresponds to upstream MLIR XeVM lowering pipeline # and is payload independent. 
From 50a34bdef9019ac0680250b1ad02b11ea0eaf8ce Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Wed, 26 Nov 2025 18:13:09 +0200 Subject: [PATCH 04/16] func.func decorator and snake_case op names --- .../examples/xegpu_matmul/execution_engine.py | 47 ++++---- .../xegpu_matmul/payload_generator.py | 102 +++++++++--------- 2 files changed, 76 insertions(+), 73 deletions(-) diff --git a/python/examples/xegpu_matmul/execution_engine.py b/python/examples/xegpu_matmul/execution_engine.py index 63cfbc4..ff679e8 100644 --- a/python/examples/xegpu_matmul/execution_engine.py +++ b/python/examples/xegpu_matmul/execution_engine.py @@ -105,7 +105,7 @@ def benchmark( for op in payload_module.operation.regions[0].blocks[0]: if ( isinstance(op, func.FuncOp) - and str(op.name).strip('"') == workload.payload_function_name + and op.name.value == workload.payload_function_name ): payload_func = op break @@ -116,32 +116,31 @@ def benchmark( with ir.InsertionPoint(payload_module.body): # define rtclock function f64_t = ir.F64Type.get() - f = func.FuncOp("rtclock", ((), (f64_t,)), visibility="private") + func.FuncOp("rtclock", ((), (f64_t,)), visibility="private") # emit benchmark function time_memref_t = ir.MemRefType.get((nruns,), f64_t) args = payload_arguments + [time_memref_t] - f = func.FuncOp("benchmark", (tuple(args), ())) - f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - with ir.InsertionPoint(f.add_entry_block()): - index_t = ir.IndexType.get() - zero = arith.ConstantOp(index_t, 0) - one = arith.ConstantOp(index_t, 1) - nwarmup_cst = arith.ConstantOp(index_t, nwarmup) - for_op = scf.ForOp(zero, nwarmup_cst, one) - with ir.InsertionPoint(for_op.body): - func.CallOp(payload_func, list(f.arguments[: len(payload_arguments)])) - scf.YieldOp(()) - nruns_cst = arith.ConstantOp(index_t, nruns) - for_op = scf.ForOp(zero, nruns_cst, one) - i = for_op.induction_variable - with ir.InsertionPoint(for_op.body): - tic = func.CallOp((f64_t,), "rtclock", ()).result - 
func.CallOp(payload_func, list(f.arguments[: len(payload_arguments)])) - toc = func.CallOp((f64_t,), "rtclock", ()).result - time = arith.SubFOp(toc, tic) - memref.StoreOp(time, f.arguments[-1], [i]) - scf.YieldOp(()) - func.ReturnOp(()) + + @func.func(*args) + def benchmark(*args): + index_t = ir.IndexType.get() + zero = arith.constant(index_t, 0) + one = arith.constant(index_t, 1) + nwarmup_cst = arith.constant(index_t, nwarmup) + for i in scf.for_(zero, nwarmup_cst, one): + # FIXME(upstream): func.call is broken for this use case? + func.CallOp(payload_func, list(args[: len(payload_arguments)])) + scf.yield_(()) + nruns_cst = arith.constant(index_t, nruns) + for i in scf.for_(zero, nruns_cst, one): + tic = func.call((f64_t,), "rtclock", ()) + func.CallOp(payload_func, list(args[: len(payload_arguments)])) + toc = func.call((f64_t,), "rtclock", ()) + time = arith.subf(toc, tic) + memref.store(time, args[-1], [i]) + scf.yield_(()) + + benchmark.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() # lower apply_transform_schedule( diff --git a/python/examples/xegpu_matmul/payload_generator.py b/python/examples/xegpu_matmul/payload_generator.py index 7ceb78c..c8c6cb7 100644 --- a/python/examples/xegpu_matmul/payload_generator.py +++ b/python/examples/xegpu_matmul/payload_generator.py @@ -8,26 +8,27 @@ def emit_gpu_alloc(mod: ir.Module, suffix: str, element_type: ir.Type, rank: int index_t = ir.IndexType.get() i32_t = ir.IntegerType.get_signless(32) with ir.InsertionPoint(mod.body): - f = func.FuncOp("gpu_alloc_" + suffix, (rank * (i32_t,), (memref_dyn_t,))) - f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - with ir.InsertionPoint(f.add_entry_block()): - dims = [ - arith.IndexCastOp(index_t, f.arguments[0]), - arith.IndexCastOp(index_t, f.arguments[1]), - ] - alloc = gpu.alloc(memref_dyn_t, None, [], dims, []) - func.ReturnOp((alloc,)) + inputs = rank * (i32_t,) + + @func.func(*inputs, name="gpu_alloc_" + suffix) + def alloc_func(*shape): + dims 
= [arith.index_cast(index_t, a) for a in shape] + alloc = gpu.alloc(memref_dyn_t, None, [], dims, []) + return alloc + + alloc_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() def emit_gpu_dealloc(mod: ir.Module, suffix: str, element_type: ir.Type, rank: int = 2): dyn = ir.ShapedType.get_dynamic_size() memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type) with ir.InsertionPoint(mod.body): - f = func.FuncOp("gpu_dealloc_" + suffix, ((memref_dyn_t,), ())) - f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - with ir.InsertionPoint(f.add_entry_block()): - gpu.dealloc(None, [], f.arguments[0]) - func.ReturnOp(()) + + @func.func(memref_dyn_t, name="gpu_dealloc_" + suffix) + def dealloc_func(memref): + gpu.dealloc(None, [], memref) + + dealloc_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() def emit_gpu_copy(mod: ir.Module, suffix: str, element_type: ir.Type, rank: int = 2): @@ -35,13 +36,12 @@ def emit_gpu_copy(mod: ir.Module, suffix: str, element_type: ir.Type, rank: int dyn = ir.ShapedType.get_dynamic_size() memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type) with ir.InsertionPoint(mod.body): - f = func.FuncOp("gpu_copy_" + suffix, ((memref_dyn_t, memref_dyn_t), ())) - f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - with ir.InsertionPoint(f.add_entry_block()): - src = f.arguments[0] - dst = f.arguments[1] - gpu.memcpy(None, [], dst, src) - func.ReturnOp(()) + + @func.func(memref_dyn_t, memref_dyn_t, name="gpu_copy_" + suffix) + def copy_func(src, dst): + gpu.memcpy(None, [], dst, src) + + copy_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() def emit_gpu_util_funcs(mod: ir.Module, element_type: ir.Type): @@ -86,35 +86,39 @@ def generate_matmul_payload( if has_bias: fargs.append(memref_bias_t) fargs.append(memref_c_t) - f = func.FuncOp(func_name, (tuple(fargs), ())) - f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - with ir.InsertionPoint(f.add_entry_block()): 
- A = f.arguments[0] - B = f.arguments[1] - C = f.arguments[-1] - a_tensor = bufferization.ToTensorOp(tensor_a_t, A, restrict=True) - b_tensor = bufferization.ToTensorOp(tensor_b_t, B, restrict=True) - c_tensor = bufferization.ToTensorOp(tensor_c_t, C, restrict=True, writable=True) - mmul = linalg.matmul(a_tensor, b_tensor, outs=[c_tensor]) - terminal = mmul - if has_bias: - bias = f.arguments[2] - bias_tensor = bufferization.ToTensorOp( - tensor_bias_t, bias, restrict=True, writable=True + + @func.func(*fargs, name=func_name) + def payload(*args): + A = args[0] + B = args[1] + C = args[-1] + a_tensor = bufferization.to_tensor(tensor_a_t, A, restrict=True) + b_tensor = bufferization.to_tensor(tensor_b_t, B, restrict=True) + c_tensor = bufferization.to_tensor( + tensor_c_t, C, restrict=True, writable=True + ) + + mmul = linalg.matmul(a_tensor, b_tensor, outs=[c_tensor]) + terminal = mmul + if has_bias: + bias = args[2] + bias_tensor = bufferization.to_tensor( + tensor_bias_t, bias, restrict=True, writable=True + ) + empty = tensor.empty((M, N), c_type) + bcast = linalg.broadcast(bias_tensor, outs=[empty], dimensions=[0]) + terminal = linalg.add(bcast, terminal, outs=[empty]) + if has_relu: + zero = arith.constant(c_type, 0.0) + empty = tensor.empty((M, N), c_type) + zero_tensor = linalg.fill(zero, outs=[empty]) + terminal = linalg.max(terminal, zero_tensor, outs=[empty]) + + bufferization.materialize_in_destination( + None, terminal, C, restrict=True, writable=True ) - empty = tensor.empty((M, N), c_type) - bcast = linalg.broadcast(bias_tensor, outs=[empty], dimensions=[0]) - terminal = linalg.add(bcast, terminal, outs=[empty]) - if has_relu: - zero = arith.constant(c_type, 0.0) - empty = tensor.empty((M, N), c_type) - zero_tensor = linalg.fill(zero, outs=[empty]) - terminal = linalg.max(terminal, zero_tensor, outs=[empty]) - - bufferization.MaterializeInDestinationOp( - None, terminal, C, restrict=True, writable=True - ) - func.ReturnOp(()) + + 
payload.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() emit_gpu_util_funcs(mod, ab_type) if c_type != ab_type: From f8b1a67a8d2659b06dad5f8531fa3c26a6506296 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Wed, 26 Nov 2025 18:32:21 +0200 Subject: [PATCH 05/16] payload_generator: do not pass module to helper func generators --- .../xegpu_matmul/payload_generator.py | 53 +++++++++---------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/python/examples/xegpu_matmul/payload_generator.py b/python/examples/xegpu_matmul/payload_generator.py index c8c6cb7..0cf3a45 100644 --- a/python/examples/xegpu_matmul/payload_generator.py +++ b/python/examples/xegpu_matmul/payload_generator.py @@ -2,57 +2,54 @@ from mlir.dialects import func, linalg, gpu, bufferization, arith, tensor -def emit_gpu_alloc(mod: ir.Module, suffix: str, element_type: ir.Type, rank: int = 2): +def emit_gpu_alloc(suffix: str, element_type: ir.Type, rank: int = 2): dyn = ir.ShapedType.get_dynamic_size() memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type) index_t = ir.IndexType.get() i32_t = ir.IntegerType.get_signless(32) - with ir.InsertionPoint(mod.body): - inputs = rank * (i32_t,) + inputs = rank * (i32_t,) - @func.func(*inputs, name="gpu_alloc_" + suffix) - def alloc_func(*shape): - dims = [arith.index_cast(index_t, a) for a in shape] - alloc = gpu.alloc(memref_dyn_t, None, [], dims, []) - return alloc + @func.func(*inputs, name="gpu_alloc_" + suffix) + def alloc_func(*shape): + dims = [arith.index_cast(index_t, a) for a in shape] + alloc = gpu.alloc(memref_dyn_t, None, [], dims, []) + return alloc - alloc_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + alloc_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() -def emit_gpu_dealloc(mod: ir.Module, suffix: str, element_type: ir.Type, rank: int = 2): +def emit_gpu_dealloc(suffix: str, element_type: ir.Type, rank: int = 2): dyn = ir.ShapedType.get_dynamic_size() memref_dyn_t 
= ir.MemRefType.get(rank * (dyn,), element_type) - with ir.InsertionPoint(mod.body): - @func.func(memref_dyn_t, name="gpu_dealloc_" + suffix) - def dealloc_func(memref): - gpu.dealloc(None, [], memref) + @func.func(memref_dyn_t, name="gpu_dealloc_" + suffix) + def dealloc_func(memref): + gpu.dealloc(None, [], memref) - dealloc_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + dealloc_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() -def emit_gpu_copy(mod: ir.Module, suffix: str, element_type: ir.Type, rank: int = 2): +def emit_gpu_copy(suffix: str, element_type: ir.Type, rank: int = 2): """Emit GPU copy function.""" dyn = ir.ShapedType.get_dynamic_size() memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type) - with ir.InsertionPoint(mod.body): - @func.func(memref_dyn_t, memref_dyn_t, name="gpu_copy_" + suffix) - def copy_func(src, dst): - gpu.memcpy(None, [], dst, src) + @func.func(memref_dyn_t, memref_dyn_t, name="gpu_copy_" + suffix) + def copy_func(src, dst): + gpu.memcpy(None, [], dst, src) - copy_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + copy_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() -def emit_gpu_util_funcs(mod: ir.Module, element_type: ir.Type): +def emit_gpu_util_funcs(element_type: ir.Type): """Emit GPU utility functions for allocation, deallocation and copy.""" suffix = { ir.F16Type.get(): "f16", ir.F32Type.get(): "f32", }[element_type] - emit_gpu_alloc(mod, suffix, element_type) - emit_gpu_dealloc(mod, suffix, element_type) - emit_gpu_copy(mod, suffix, element_type) + emit_gpu_alloc(suffix, element_type) + emit_gpu_dealloc(suffix, element_type) + emit_gpu_copy(suffix, element_type) def generate_matmul_payload( @@ -120,8 +117,8 @@ def payload(*args): payload.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - emit_gpu_util_funcs(mod, ab_type) - if c_type != ab_type: - emit_gpu_util_funcs(mod, c_type) + emit_gpu_util_funcs(ab_type) + if 
c_type != ab_type: + emit_gpu_util_funcs(c_type) return mod From 80fe8338a87fd779505b85bd616d324946e16f5e Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Wed, 26 Nov 2025 19:06:45 +0200 Subject: [PATCH 06/16] simplify schedule and use snake_case op names where possible --- python/examples/xegpu_matmul/schedule.py | 56 ++++++++---------------- 1 file changed, 18 insertions(+), 38 deletions(-) diff --git a/python/examples/xegpu_matmul/schedule.py b/python/examples/xegpu_matmul/schedule.py index 4b0cddf..7ee61de 100644 --- a/python/examples/xegpu_matmul/schedule.py +++ b/python/examples/xegpu_matmul/schedule.py @@ -25,48 +25,41 @@ def get_schedule_module( mod = ir.Module.create() mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() with ir.InsertionPoint(mod.body): - named_sequence = transform.NamedSequenceOp( + named_sequence = transform.named_sequence( "__transform_main", [transform.AnyOpType.get()], # input types [], # output types arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], ) with ir.InsertionPoint(named_sequence.body): + # match the payload module + anytype = transform.AnyOpType.get() + func = match(named_sequence.bodyTarget, ops={"func.func"}) + payload_mod = transform.get_parent_op( + anytype, + func, + op_name="builtin.module", + deduplicate=True, + ) xegpu_matmul_transform_schedule( - named_sequence, + payload_mod, has_bias=has_bias, has_relu=has_relu, dump_kernel=dump_kernel, params=params, ) - # placeholder for parameter division op - i32 = ir.IntegerType.get_signless(32) - paramInt32Type = transform.ParamType.get(i32) - div_named_sequence = transform.NamedSequenceOp( - "param_div", - [paramInt32Type, paramInt32Type], # input types - [paramInt32Type], # output types - arg_attrs=[ - {"transform.readonly": ir.UnitAttr.get()}, - {"transform.readonly": ir.UnitAttr.get()}, - ], - ) - with ir.InsertionPoint(div_named_sequence.body): - p = transform.ParamConstantOp(paramInt32Type, ir.IntegerAttr.get(i32, 1)) - 
transform.YieldOp(p) return mod def xegpu_matmul_transform_schedule( - named_sequence: transform.NamedSequenceOp, + mod: ir.Value, has_bias: bool = False, has_relu: bool = False, dump_kernel: str = "", params: Optional[dict] = None, ): """Transform schedule for matmul-like payload.""" - mod = bundle_header(named_sequence) mod, interrupted = bundle_xepu_matmul_schedule( mod, has_bias=has_bias, @@ -75,27 +68,14 @@ def xegpu_matmul_transform_schedule( params=params, ) if interrupted: - transform.YieldOp() + transform.yield_() return mod, interrupted = bundle_xegpu_to_binary( mod, dump_kernel=dump_kernel, ) - transform.YieldOp() - - -def bundle_header(named_sequence: transform.NamedSequenceOp): - """Matches the payload module.""" - anytype = transform.AnyOpType.get() - func = match(named_sequence.bodyTarget, ops={"func.func"}) - mod = transform.get_parent_op( - anytype, - func, - op_name="builtin.module", - deduplicate=True, - ) - return mod + transform.yield_() def bundle_xepu_matmul_schedule( @@ -217,7 +197,7 @@ def bundle_xepu_matmul_schedule( # convert forall to parallel wg_loop = match(mod, ops={"scf.forall"}) - wg_loop = loop.ForallToParallelOp([anytype], wg_loop) + wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) func = transform.get_parent_op(anytype, wg_loop) # convert to scf.parallel to gpu.launch @@ -257,9 +237,9 @@ def bundle_xepu_matmul_schedule( # add layouts to DPAS op operands k_loop = match(gpu_func, ops={"scf.for"}) dpas_op = match(k_loop, ops={"xegpu.dpas"}) - tile_a = transform.GetOperandOp(anyvalue, dpas_op, [0]) - tile_b = transform.GetOperandOp(anyvalue, dpas_op, [1]) - tile_c = transform.GetOperandOp(anyvalue, dpas_op, [2]) + tile_a = transform.get_operand(anyvalue, dpas_op, [0]) + tile_b = transform.get_operand(anyvalue, dpas_op, [1]) + tile_c = transform.get_operand(anyvalue, dpas_op, [2]) def convert_layout(value, input, target): xegpu.convert_layout( From 804045eee827f839c7dd77b57188de5e647e9ce6 Mon Sep 17 00:00:00 2001 From: 
Tuomas Karna Date: Wed, 26 Nov 2025 21:54:12 +0200 Subject: [PATCH 07/16] schedule: add comment about gpu-lower-to-xevm-pipeline --- python/examples/xegpu_matmul/schedule.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/examples/xegpu_matmul/schedule.py b/python/examples/xegpu_matmul/schedule.py index 7ee61de..86158e2 100644 --- a/python/examples/xegpu_matmul/schedule.py +++ b/python/examples/xegpu_matmul/schedule.py @@ -370,6 +370,11 @@ def bundle_xegpu_to_binary(mod, dump_kernel: str = ""): # This schedule corresponds to upstream MLIR XeVM lowering pipeline # and is payload independent. + # TODO applying gpu-lower-to-xevm-pipeline pass affects performance + # mod = apply_registered_pass( + # mod, "gpu-lower-to-xevm-pipeline", options={"xegpu-op-level": "workgroup"} + # ) + gpu_mod = match(mod, ops={"gpu.module"}) # xegpu distribution gpu_func = match(gpu_mod, ops={"gpu.func"}) From 116482e550e4c1ecb707c48018a6c3aeee03fb0c Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Thu, 27 Nov 2025 11:11:48 +0200 Subject: [PATCH 08/16] rename execution_engine.py -> runner.py --- python/examples/xegpu_matmul/matmul.py | 2 +- python/examples/xegpu_matmul/{execution_engine.py => runner.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename python/examples/xegpu_matmul/{execution_engine.py => runner.py} (100%) diff --git a/python/examples/xegpu_matmul/matmul.py b/python/examples/xegpu_matmul/matmul.py index 097dbd5..3291061 100644 --- a/python/examples/xegpu_matmul/matmul.py +++ b/python/examples/xegpu_matmul/matmul.py @@ -18,7 +18,7 @@ from schedule import get_schedule_module from payload_generator import generate_matmul_payload -from execution_engine import lower_payload, benchmark +from runner import lower_payload, benchmark import argparse diff --git a/python/examples/xegpu_matmul/execution_engine.py b/python/examples/xegpu_matmul/runner.py similarity index 100% rename from python/examples/xegpu_matmul/execution_engine.py rename to 
python/examples/xegpu_matmul/runner.py From 94205b5e72b426968aff7b6126b6322af9525640 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Thu, 27 Nov 2025 11:13:40 +0200 Subject: [PATCH 09/16] rename payload_generator.py -> payload.py --- python/examples/xegpu_matmul/matmul.py | 2 +- .../examples/xegpu_matmul/{payload_generator.py => payload.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename python/examples/xegpu_matmul/{payload_generator.py => payload.py} (100%) diff --git a/python/examples/xegpu_matmul/matmul.py b/python/examples/xegpu_matmul/matmul.py index 3291061..26b849a 100644 --- a/python/examples/xegpu_matmul/matmul.py +++ b/python/examples/xegpu_matmul/matmul.py @@ -16,7 +16,7 @@ from functools import cached_property from lighthouse.utils import get_packed_arg, memref_to_ctype from schedule import get_schedule_module -from payload_generator import generate_matmul_payload +from payload import generate_matmul_payload from runner import lower_payload, benchmark import argparse diff --git a/python/examples/xegpu_matmul/payload_generator.py b/python/examples/xegpu_matmul/payload.py similarity index 100% rename from python/examples/xegpu_matmul/payload_generator.py rename to python/examples/xegpu_matmul/payload.py From dcd62911707bc1c46e502a4e9dd746b21bbb642d Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Thu, 27 Nov 2025 11:20:18 +0200 Subject: [PATCH 10/16] simplify allocation: context manager returns input memrefs --- python/examples/xegpu_matmul/matmul.py | 13 ++++--------- python/examples/xegpu_matmul/runner.py | 6 ++---- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/python/examples/xegpu_matmul/matmul.py b/python/examples/xegpu_matmul/matmul.py index 26b849a..4ac7230 100644 --- a/python/examples/xegpu_matmul/matmul.py +++ b/python/examples/xegpu_matmul/matmul.py @@ -92,11 +92,6 @@ def _allocate_array( self.gpu_memrefs[key] = mref return mref - def _allocate_inputs(self, execution_engine: ExecutionEngine): - 
self._allocate_array("A", self.a_shape, self.ab_type, execution_engine) - self._allocate_array("B", self.b_shape, self.ab_type, execution_engine) - self._allocate_array("C", self.c_shape, self.c_type, execution_engine) - def _deallocate_all(self, execution_engine: ExecutionEngine): for (_, dtype_str), mref in self.gpu_memrefs.items(): dealloc_func = execution_engine.lookup("gpu_dealloc_" + dtype_str) @@ -105,10 +100,10 @@ def _deallocate_all(self, execution_engine: ExecutionEngine): self.gpu_memrefs = {} @contextmanager - def allocate(self, execution_engine: ExecutionEngine): + def allocate_inputs(self, execution_engine: ExecutionEngine): try: - self._allocate_inputs(execution_engine) - yield None + inputs = self._get_input_arrays(execution_engine) + yield inputs finally: self._deallocate_all(execution_engine) @@ -141,7 +136,7 @@ def _reference_solution(self) -> np.ndarray: raise NotImplementedError("Bias verification not implemented") return C_ref - def get_input_arrays( + def _get_input_arrays( self, execution_engine: ExecutionEngine ) -> list[ctypes.Structure]: A_gpu = self._allocate_array("A", self.a_shape, self.ab_type, execution_engine) diff --git a/python/examples/xegpu_matmul/runner.py b/python/examples/xegpu_matmul/runner.py index ff679e8..f9b0bbd 100644 --- a/python/examples/xegpu_matmul/runner.py +++ b/python/examples/xegpu_matmul/runner.py @@ -73,9 +73,8 @@ def execute( # get execution engine engine = get_engine(payload_module, requirements=workload.requirements()) - with workload.allocate(execution_engine=engine): + with workload.allocate_inputs(execution_engine=engine) as inputs: # prepare function arguments - inputs = workload.get_input_arrays(execution_engine=engine) pointers = [ctypes.pointer(ctypes.pointer(m)) for m in inputs] packed_args = get_packed_arg(pointers) @@ -150,8 +149,7 @@ def benchmark(*args): # get execution engine, rtclock requires mlir_c_runner engine = get_engine(payload_module) - with workload.allocate(execution_engine=engine): - 
inputs = workload.get_input_arrays(execution_engine=engine) + with workload.allocate_inputs(execution_engine=engine) as inputs: pointers = [ctypes.pointer(ctypes.pointer(m)) for m in inputs] if check_correctness: # call payload once to verify correctness From 01fcf33dd9a4b8b71734e8ef596f416eb6cb5d57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tuomas=20K=C3=A4rn=C3=A4?= Date: Thu, 27 Nov 2025 11:29:28 +0200 Subject: [PATCH 11/16] Update README Co-authored-by: Adam Siemieniuk --- python/examples/xegpu_matmul/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/examples/xegpu_matmul/README.md b/python/examples/xegpu_matmul/README.md index 0e4c96e..0070c62 100644 --- a/python/examples/xegpu_matmul/README.md +++ b/python/examples/xegpu_matmul/README.md @@ -49,7 +49,7 @@ If cmake cannot find LevelZero, set environment variable `LEVEL_ZERO_DIR= Date: Thu, 27 Nov 2025 11:37:53 +0200 Subject: [PATCH 12/16] mlir_utils: snake_case op names where possible --- python/examples/xegpu_matmul/mlir_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/examples/xegpu_matmul/mlir_utils.py b/python/examples/xegpu_matmul/mlir_utils.py index 863c337..7e9c3bf 100644 --- a/python/examples/xegpu_matmul/mlir_utils.py +++ b/python/examples/xegpu_matmul/mlir_utils.py @@ -17,7 +17,7 @@ def cse(op): def canonicalize(op): - with ir.InsertionPoint(transform.ApplyPatternsOp(op).patterns): + with ir.InsertionPoint(transform.apply_patterns(op).patterns): transform.ApplyCanonicalizationPatternsOp() From 32354fd772c07930df541fcc411a4280ce5ceba7 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Thu, 27 Nov 2025 11:48:39 +0200 Subject: [PATCH 13/16] remove element type options, only (f16,f16,f32) can be lowered --- python/examples/xegpu_matmul/matmul.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/python/examples/xegpu_matmul/matmul.py b/python/examples/xegpu_matmul/matmul.py index 4ac7230..e010490 
100644 --- a/python/examples/xegpu_matmul/matmul.py +++ b/python/examples/xegpu_matmul/matmul.py @@ -44,7 +44,7 @@ def __init__( M: int, N: int, K: int, - ab_type: str = "f32", + ab_type: str = "f16", c_type: str = "f32", has_bias: bool = False, has_relu: bool = False, @@ -55,6 +55,8 @@ def __init__( self.a_shape = (M, K) self.b_shape = (K, N) self.c_shape = (M, N) + assert ab_type == "f16", "Only f16 type is supported for A and B" + assert c_type == "f32", "Only f32 type is supported for C" self.ab_type = ab_type self.c_type = c_type type_str_to_numpy = { @@ -282,20 +284,6 @@ def parse_cli(): default=1, help="Number of initial prefetches.", ) - parser.add_argument( - "--ab-type", - type=str, - choices=["f16", "f32"], - default="f16", - help="Data type of A and B matrices.", - ) - parser.add_argument( - "--c-type", - type=str, - choices=["f16", "f32"], - default="f32", - help="Data type of the C matrix.", - ) parser.add_argument( "--nruns", type=int, @@ -359,14 +347,16 @@ def parse_cli(): } M, N, K = args.sizes + ab_type = "f16" + c_type = "f32" with ir.Context(), ir.Location.unknown(): wload = XeGPUMatMul( M=M, N=N, K=K, - ab_type=args.ab_type, - c_type=args.c_type, + ab_type=ab_type, + c_type=c_type, has_bias=False, has_relu=args.relu, ) @@ -397,7 +387,7 @@ def list2str(a): parts = [ f"sizes={list2str(args.sizes)}", - f"dt={args.ab_type},{args.c_type}", + f"dt={ab_type},{c_type}", f"wg-tile={list2str(args.wg_tile)}", f"sg-tile={list2str(args.sg_tile)}", f"k-tile={args.k_tile}", From 958d456a34b96fd387ddf9239000fbb8e3543f05 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Thu, 27 Nov 2025 15:50:59 +0200 Subject: [PATCH 14/16] enable simple CI test --- .github/workflows/examples.yml | 2 +- python/examples/xegpu_matmul/lit.local.cfg | 1 + python/examples/xegpu_matmul/matmul.py | 3 +++ 3 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 python/examples/xegpu_matmul/lit.local.cfg diff --git a/.github/workflows/examples.yml 
b/.github/workflows/examples.yml index 3396fb1..b889a85 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -28,4 +28,4 @@ jobs: - name: Run lit-enabled examples as tests run: | export FILECHECK=FileCheck-18 # Ubuntu's llvm-dev appends a version number. - uv run lit python/examples # Makes sure to substitute FileCheck for $FILECHECK + uv run lit python/examples --verbose # Makes sure to substitute FileCheck for $FILECHECK diff --git a/python/examples/xegpu_matmul/lit.local.cfg b/python/examples/xegpu_matmul/lit.local.cfg new file mode 100644 index 0000000..378d891 --- /dev/null +++ b/python/examples/xegpu_matmul/lit.local.cfg @@ -0,0 +1 @@ +config.excludes = ["README.md", "mlir_utils.py", "payload.py", "runner.py", "schedule.py"] diff --git a/python/examples/xegpu_matmul/matmul.py b/python/examples/xegpu_matmul/matmul.py index e010490..6937b74 100644 --- a/python/examples/xegpu_matmul/matmul.py +++ b/python/examples/xegpu_matmul/matmul.py @@ -1,3 +1,6 @@ +# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s +# CHECK: module attributes {gpu.container_module} { + """ XeGPU matrix multiplication benchmark. 
""" From c82ecd44b4e9382993f7d856931ec84bf9e4494b Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Fri, 28 Nov 2025 12:07:23 +0200 Subject: [PATCH 15/16] number of warmup iters is now configurable + clean up lit config --- python/examples/xegpu_matmul/lit.local.cfg | 2 +- python/examples/xegpu_matmul/matmul.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/python/examples/xegpu_matmul/lit.local.cfg b/python/examples/xegpu_matmul/lit.local.cfg index 378d891..b310830 100644 --- a/python/examples/xegpu_matmul/lit.local.cfg +++ b/python/examples/xegpu_matmul/lit.local.cfg @@ -1 +1 @@ -config.excludes = ["README.md", "mlir_utils.py", "payload.py", "runner.py", "schedule.py"] +config.excludes = ["mlir_utils.py", "payload.py", "runner.py", "schedule.py"] diff --git a/python/examples/xegpu_matmul/matmul.py b/python/examples/xegpu_matmul/matmul.py index 6937b74..32b397f 100644 --- a/python/examples/xegpu_matmul/matmul.py +++ b/python/examples/xegpu_matmul/matmul.py @@ -293,6 +293,12 @@ def parse_cli(): default=1000, help="Number of runs to average the execution time.", ) + parser.add_argument( + "--nwarmup", + type=int, + default=20, + help="Number of warm-up iterations before benchmarking.", + ) parser.add_argument( "--relu", action="store_true", @@ -375,7 +381,7 @@ def parse_cli(): times = benchmark( wload, nruns=args.nruns, - nwarmup=15, + nwarmup=args.nwarmup, check_correctness=args.check_result, schedule_parameters=params, verbose=1, From b34570736f5634d53e42f06c9e02106115a1f1f9 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Fri, 28 Nov 2025 12:08:20 +0200 Subject: [PATCH 16/16] more snake_case transform ops --- python/examples/xegpu_matmul/mlir_utils.py | 8 +--- python/examples/xegpu_matmul/schedule.py | 49 ++++++++++++---------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/python/examples/xegpu_matmul/mlir_utils.py b/python/examples/xegpu_matmul/mlir_utils.py index 7e9c3bf..2b5a4ba 100644 --- 
a/python/examples/xegpu_matmul/mlir_utils.py +++ b/python/examples/xegpu_matmul/mlir_utils.py @@ -9,16 +9,12 @@ def apply_registered_pass(*args, **kwargs): def match(*args, **kwargs): - return structured.MatchOp(transform.AnyOpType.get(), *args, **kwargs) - - -def cse(op): - transform.ApplyCommonSubexpressionEliminationOp(op) + return structured.structured_match(transform.AnyOpType.get(), *args, **kwargs) def canonicalize(op): with ir.InsertionPoint(transform.apply_patterns(op).patterns): - transform.ApplyCanonicalizationPatternsOp() + transform.apply_patterns_canonicalization() def get_mlir_library_path(): diff --git a/python/examples/xegpu_matmul/schedule.py b/python/examples/xegpu_matmul/schedule.py index 86158e2..355177a 100644 --- a/python/examples/xegpu_matmul/schedule.py +++ b/python/examples/xegpu_matmul/schedule.py @@ -5,7 +5,7 @@ from mlir.dialects.bufferization import LayoutMapOption from mlir.dialects import transform from mlir.dialects.transform import structured -from mlir_utils import apply_registered_pass, match, cse, canonicalize +from mlir_utils import apply_registered_pass, match, canonicalize from typing import Optional @@ -139,28 +139,31 @@ def bundle_xepu_matmul_schedule( # wg tiling if has_relu: - terminal = match(mod, ops={"linalg.max"}).result + terminal = match(mod, ops={"linalg.max"}) elif has_bias: - terminal = match(mod, ops={"linalg.add"}).result + terminal = match(mod, ops={"linalg.add"}) else: - terminal = match(mod, ops={"linalg.matmul"}).result + terminal = match(mod, ops={"linalg.matmul"}) + # FIXME use structured.structured_fuse structured.FuseOp(terminal, tile_sizes=wg_tile, use_forall=True) - cse(mod) + transform.apply_cse(mod) canonicalize(mod) # k loop tiling - wg_matmul = match(mod, ops={"linalg.matmul"}).result + wg_matmul = match(mod, ops={"linalg.matmul"}) + # FIXME use structured.structured_tile_using_for wgk_matmul, k_loop = structured.TileUsingForOp( wg_matmul, sizes=[0, 0, k_tile] ).results - cse(func) + 
transform.apply_cse(func) canonicalize(func) if dump_kernel == "tiled": return mod, True # vectorize + # FIXME use structured.structured_vectorize_children_and_apply_patterns func = structured.VectorizeChildrenAndApplyPatternsOp( func, fold_type_extensions_into_contract=True, @@ -170,7 +173,7 @@ def bundle_xepu_matmul_schedule( k_loop = match(func, ops={"scf.for"}) loop.HoistLoopInvariantSubsetsOp(k_loop) - cse(func) + transform.apply_cse(func) canonicalize(func) if dump_kernel == "vectorized": @@ -189,7 +192,7 @@ def bundle_xepu_matmul_schedule( ).result # fold memref.subviews into vector.transfer_read/write ops mod = apply_registered_pass(mod, "fold-memref-alias-ops") - cse(mod) + transform.apply_cse(mod) canonicalize(mod) if dump_kernel == "bufferized": @@ -204,7 +207,7 @@ def bundle_xepu_matmul_schedule( func = apply_registered_pass(func, "gpu-map-parallel-loops") func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") func = apply_registered_pass(func, "lower-affine") - cse(func) + transform.apply_cse(func) canonicalize(func) # set correct number of gpu threads @@ -216,7 +219,7 @@ def bundle_xepu_matmul_schedule( canonicalize(func) func = apply_registered_pass(func, "gpu-launch-sink-index-computations") mod = apply_registered_pass(mod, "gpu-kernel-outlining") - cse(mod) + transform.apply_cse(mod) # set xevm target mod = apply_registered_pass( @@ -229,7 +232,7 @@ def bundle_xepu_matmul_schedule( gpu_mod = match(mod, ops={"gpu.module"}) gpu_func = match(gpu_mod, ops={"gpu.func"}) gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") - cse(gpu_func) + transform.apply_cse(gpu_func) if dump_kernel == "xegpu-initial": return mod, True @@ -319,14 +322,14 @@ def convert_layout(value, input, target): if has_relu: # for post ops we need to add C layout manually - max_op = match(gpu_func, ops={"arith.maximumf"}).result + max_op = match(gpu_func, ops={"arith.maximumf"}) xegpu.set_op_layout_attr(max_op, result=True, index=0, **output_layout) # 
find zero constant buffer and annotate it const_buffer = transform.get_producer_of_operand(anytype, max_op, 1) xegpu.set_op_layout_attr(const_buffer, result=True, index=0, **output_layout) if has_bias: # for post ops we need to add C layout manually - add_op = match(gpu_func, ops={"arith.addf"}).result + add_op = match(gpu_func, ops={"arith.addf"}) xegpu.set_op_layout_attr(add_op, result=True, index=0, **output_layout) # annotate broadcast op operands @@ -350,14 +353,14 @@ def convert_layout(value, input, target): mask = transform.get_producer_of_operand(anytype, bcast_load, 2) xegpu.set_op_layout_attr(mask, result=True, index=0, **output_layout_dim1) raise NotImplementedError("Bias layout propagation is not supported.") - cse(gpu_func) + transform.apply_cse(gpu_func) canonicalize(gpu_func) # hoist desc ops out of reduction loop transform.apply_licm(k_loop) canonicalize(gpu_func) - cse(gpu_func) + transform.apply_cse(gpu_func) if dump_kernel == "xegpu-wg": return mod, True @@ -379,16 +382,16 @@ def bundle_xegpu_to_binary(mod, dump_kernel: str = ""): # xegpu distribution gpu_func = match(gpu_mod, ops={"gpu.func"}) gpu_func = apply_registered_pass(gpu_func, "xegpu-wg-to-sg-distribute") - cse(gpu_func) + transform.apply_cse(gpu_func) if dump_kernel == "xegpu-sg": return mod, True gpu_func = apply_registered_pass(gpu_func, "lower-affine") - cse(gpu_func) + transform.apply_cse(gpu_func) gpu_func = apply_registered_pass(gpu_func, "xegpu-blocking") canonicalize(gpu_func) - cse(gpu_func) + transform.apply_cse(gpu_func) if dump_kernel == "xegpu-inst": return mod, True @@ -396,16 +399,16 @@ def bundle_xegpu_to_binary(mod, dump_kernel: str = ""): gpu_func = apply_registered_pass(gpu_func, "xegpu-propagate-layout") gpu_mod = apply_registered_pass(gpu_mod, "xegpu-subgroup-distribute") canonicalize(gpu_mod) - cse(gpu_mod) + transform.apply_cse(gpu_mod) gpu_mod = apply_registered_pass(gpu_mod, "loop-invariant-code-motion") - cse(gpu_mod) + transform.apply_cse(gpu_mod) gpu_mod = 
apply_registered_pass(gpu_mod, "xegpu-vector-linearize") gpu_mod = apply_registered_pass(gpu_mod, "convert-xegpu-to-xevm") gpu_mod = apply_registered_pass( gpu_mod, "convert-gpu-to-llvm-spv", options={"use-64bit-index": "true"} ) gpu_mod = apply_registered_pass(gpu_mod, "convert-xevm-to-llvm") - cse(gpu_mod) + transform.apply_cse(gpu_mod) func = match(mod, ops={"func.func"}) func = apply_registered_pass(func, "gpu-async-region") @@ -424,7 +427,7 @@ def bundle_xegpu_to_binary(mod, dump_kernel: str = ""): mod = apply_registered_pass(mod, "gpu-to-llvm") mod = apply_registered_pass(mod, "lower-affine") mod = apply_registered_pass(mod, "reconcile-unrealized-casts") - cse(mod) + transform.apply_cse(mod) mod = apply_registered_pass(mod, "gpu-module-to-binary") return mod, False