|
| 1 | +import numpy as np |
| 2 | +import ctypes |
| 3 | +import os |
| 4 | +from typing import Optional |
| 5 | + |
| 6 | +from mlir.dialects.transform import interpreter as transform_interpreter |
| 7 | +from mlir.dialects import func, arith, scf, memref |
| 8 | +from mlir.execution_engine import ExecutionEngine |
| 9 | +from mlir import ir |
| 10 | +from mlir.runtime.np_to_memref import get_ranked_memref_descriptor |
| 11 | + |
| 12 | +from lighthouse.utils import get_packed_arg |
| 13 | +from mlir_utils import get_mlir_library_path |
| 14 | + |
| 15 | + |
def get_engine(payload_module, opt_level=3) -> ExecutionEngine:
    """Create and initialize an MLIR ``ExecutionEngine`` for ``payload_module``.

    The engine is constructed inside a freshly created ``ir.Context`` and an
    unknown ``ir.Location``. Three runtime shared libraries are passed to it,
    each resolved against the directory returned by
    ``get_mlir_library_path()``.

    Args:
        payload_module: Module handed to ``ExecutionEngine`` for compilation.
        opt_level: Optimization level forwarded to ``ExecutionEngine``.

    Returns:
        The engine, after ``initialize()`` has been called on it.
    """
    ctx = ir.Context()
    loc = ir.Location.unknown(ctx)
    lib_root = get_mlir_library_path()
    # Runtime libraries handed to the engine, in this order.
    shared_lib_paths = [
        os.path.join(lib_root, lib_name)
        for lib_name in (
            "libmlir_levelzero_runtime.so",
            "libmlir_runner_utils.so",
            "libmlir_c_runner_utils.so",
        )
    ]
    with ctx, loc:
        engine = ExecutionEngine(
            payload_module,
            opt_level=opt_level,
            shared_libs=shared_lib_paths,
        )
        engine.initialize()
        return engine
| 32 | + |
| 33 | + |
def apply_transform_schedule(
    payload_module,
    schedule_module,
    context,
    location,
    dump_kernel: Optional[str] = None,
    dump_schedule: bool = False,
):
    """Apply a transform schedule to a payload module, optionally dumping IR.

    Unless ``dump_kernel`` equals ``"initial"``, the first operation in the
    body of ``schedule_module`` is run as a named sequence against
    ``payload_module`` through the transform interpreter. Afterwards the
    payload module is printed when ``dump_kernel`` is truthy, and the schedule
    module is printed when ``dump_schedule`` is set.

    Args:
        payload_module: Module the schedule is applied to (and printed).
        schedule_module: Module whose first body operation is the transform
            root.
        context: Context manager entered around the interpreter call.
        location: Context manager entered together with ``context``.
        dump_kernel: When ``"initial"``, the schedule is not applied; any
            truthy value causes the payload module to be printed.
        dump_schedule: When true, print the schedule module.
    """
    # With "initial" the payload is printed exactly as received, so the
    # interpreter is skipped entirely.
    if dump_kernel != "initial":
        with context, location:
            entry_sequence = schedule_module.body.operations[0]
            transform_interpreter.apply_named_sequence(
                payload_root=payload_module,
                transform_root=entry_sequence,
                transform_module=schedule_module,
            )

    if dump_kernel:
        print(payload_module)
    if dump_schedule:
        print(schedule_module)
| 54 | + |
| 55 | + |
def lower_payload(
    workload,
    dump_kernel: Optional[str] = None,
    dump_schedule: bool = False,
    schedule_parameters: Optional[dict] = None,
) -> ir.Module:
    """Fetch the workload's payload and schedule, then apply the schedule.

    Args:
        workload: Object providing ``payload_module()``,
            ``schedule_module(...)``, ``context`` and ``location``.
        dump_kernel: Forwarded both to ``workload.schedule_module`` and to
            ``apply_transform_schedule``.
        dump_schedule: Forwarded to ``apply_transform_schedule``.
        schedule_parameters: Passed to ``workload.schedule_module`` as
            ``parameters``.

    Returns:
        The module object obtained from ``workload.payload_module()``, after
        ``apply_transform_schedule`` has been called with it.
    """
    module = workload.payload_module()
    schedule = workload.schedule_module(
        dump_kernel=dump_kernel,
        parameters=schedule_parameters,
    )
    dump_options = {"dump_kernel": dump_kernel, "dump_schedule": dump_schedule}
    apply_transform_schedule(
        module,
        schedule,
        workload.context,
        workload.location,
        **dump_options,
    )
    return module
| 75 | + |
| 76 | + |
def execute(
    workload,
    check_correctness: bool = True,
    schedule_parameters: Optional[dict] = None,
    verbose: int = 0,
):
    """Lower the workload's payload, JIT-compile it and run it once.

    Args:
        workload: Object providing the payload/schedule modules, the
            ``allocate`` context manager, ``get_input_arrays``,
            ``payload_function_name`` and ``check_correctness``.
        check_correctness: When true, call ``workload.check_correctness``
            after the payload has run. Its return value is not inspected
            here.
        schedule_parameters: Forwarded to ``lower_payload``.
        verbose: Forwarded to ``workload.check_correctness``.
    """
    # lower payload with schedule
    payload_module = lower_payload(
        workload, schedule_parameters=schedule_parameters
    )
    # get_engine() accepts only the module and an optional ``opt_level``;
    # passing an unsupported ``requirements`` keyword made every call fail
    # with TypeError before the payload could run.
    engine = get_engine(payload_module)

    with workload.allocate(execution_engine=engine):
        # prepare function arguments: each input is passed as a pointer to a
        # pointer, then packed into a single argument
        inputs = workload.get_input_arrays(execution_engine=engine)
        pointers = [ctypes.pointer(ctypes.pointer(m)) for m in inputs]
        packed_args = get_packed_arg(pointers)

        # handle to payload function
        payload_func = engine.lookup(workload.payload_function_name)

        # call
        payload_func(packed_args)

        # checked while the workload's allocation is still active
        if check_correctness:
            workload.check_correctness(
                execution_engine=engine, verbose=verbose
            )
| 108 | + |
| 109 | + |
def benchmark(
    workload,
    nruns: int = 100,
    nwarmup: int = 10,
    schedule_parameters: Optional[dict] = None,
    check_correctness: bool = True,
    verbose: int = 0,
) -> np.ndarray:
    """Time the workload's payload function and return one timing per run.

    A ``benchmark`` function is appended to the workload's payload module
    before lowering. It calls the payload ``nwarmup`` times untimed, then
    ``nruns`` times with an ``rtclock`` call before and after each
    invocation, storing ``toc - tic`` into an f64 memref passed as its last
    argument. The module is then lowered with the workload's schedule,
    JIT-compiled, and the generated function is invoked once.

    Args:
        workload: Object providing the payload/schedule modules, ``context``,
            ``location``, ``allocate``, ``get_input_arrays``,
            ``payload_function_name`` and ``check_correctness``.
        nruns: Number of timed payload invocations; also the length of the
            returned array.
        nwarmup: Number of untimed payload invocations run first.
        schedule_parameters: Passed to ``workload.schedule_module`` as
            ``parameters``.
        check_correctness: When true, run the payload once on its own and
            verify it before benchmarking.
        verbose: Forwarded to ``workload.check_correctness``.

    Returns:
        A float64 array of shape ``(nruns,)`` holding the ``rtclock``
        differences (presumably seconds -- confirm against the MLIR runner
        utils).

    Raises:
        AssertionError: If no ``func.FuncOp`` named
            ``workload.payload_function_name`` is found among the module's
            top-level operations.
        ValueError: If ``check_correctness`` is set and
            ``workload.check_correctness`` returns a falsy value.
    """
    # original (not yet lowered) payload module
    module = workload.payload_module()

    # locate the payload function among the module's top-level operations
    top_level_ops = module.operation.regions[0].blocks[0]
    payload_fn_op = next(
        (
            candidate
            for candidate in top_level_ops
            if isinstance(candidate, func.FuncOp)
            and str(candidate.name).strip('"') == workload.payload_function_name
        ),
        None,
    )
    assert payload_fn_op is not None, "Could not find payload function"
    payload_arg_types = payload_fn_op.type.inputs
    n_payload_args = len(payload_arg_types)

    # emit the timing harness next to the payload function
    with workload.context, workload.location:
        with ir.InsertionPoint(module.body):
            f64 = ir.F64Type.get()
            # private declaration of rtclock, called from the timed loop
            func.FuncOp("rtclock", ((), (f64,)), visibility="private")

            # benchmark(<payload args...>, timings: memref<nruns x f64>)
            timings_type = ir.MemRefType.get((nruns,), f64)
            bench_signature = (tuple(payload_arg_types + [timings_type]), ())
            bench_fn = func.FuncOp("benchmark", bench_signature)
            bench_fn.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()

            entry_block = bench_fn.add_entry_block()
            with ir.InsertionPoint(entry_block):
                # leading arguments are forwarded to the payload unchanged;
                # the trailing one receives the timings
                payload_operands = list(bench_fn.arguments[:n_payload_args])
                timings_arg = bench_fn.arguments[-1]

                index = ir.IndexType.get()
                lower = arith.ConstantOp(index, 0)
                step = arith.ConstantOp(index, 1)

                # warm-up loop: untimed payload calls
                warmup_upper = arith.ConstantOp(index, nwarmup)
                warmup_loop = scf.ForOp(lower, warmup_upper, step)
                with ir.InsertionPoint(warmup_loop.body):
                    func.CallOp(payload_fn_op, payload_operands)
                    scf.YieldOp(())

                # timed loop: timings[i] = rtclock() after - rtclock() before
                timed_upper = arith.ConstantOp(index, nruns)
                timed_loop = scf.ForOp(lower, timed_upper, step)
                run_index = timed_loop.induction_variable
                with ir.InsertionPoint(timed_loop.body):
                    tic = func.CallOp((f64,), "rtclock", ()).result
                    func.CallOp(payload_fn_op, payload_operands)
                    toc = func.CallOp((f64,), "rtclock", ()).result
                    elapsed = arith.SubFOp(toc, tic)
                    memref.StoreOp(elapsed, timings_arg, [run_index])
                    scf.YieldOp(())
                func.ReturnOp(())

    # lower
    schedule = workload.schedule_module(parameters=schedule_parameters)
    apply_transform_schedule(
        module, schedule, workload.context, workload.location
    )
    # get execution engine, rtclock requires mlir_c_runner
    engine = get_engine(module)

    with workload.allocate(execution_engine=engine):
        input_descriptors = workload.get_input_arrays(execution_engine=engine)
        input_ptrs = [
            ctypes.pointer(ctypes.pointer(descriptor))
            for descriptor in input_descriptors
        ]

        if check_correctness:
            # run the bare payload once and verify its result before timing
            packed_inputs = get_packed_arg(input_ptrs)
            jit_payload = engine.lookup(workload.payload_function_name)
            jit_payload(packed_inputs)
            verified = workload.check_correctness(
                execution_engine=engine, verbose=verbose
            )
            if not verified:
                raise ValueError("Benchmark verification failed.")

        # buffer the generated function writes its timings into
        timings = np.zeros((nruns,), dtype=np.float64)
        timings_descriptor = get_ranked_memref_descriptor(timings)
        timings_ptr = ctypes.pointer(ctypes.pointer(timings_descriptor))
        bench_args = get_packed_arg(input_ptrs + [timings_ptr])

        # call benchmark function
        jit_benchmark = engine.lookup("benchmark")
        jit_benchmark(bench_args)

    return timings
0 commit comments