From b95a6c99597467bb84f5e1a4b70063f1d8f054f8 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Tue, 22 Oct 2024 22:44:46 +0000 Subject: [PATCH] Dump SPIRV kernel in 'benchmark_driver.py' Signed-off-by: Anatoly Myachev --- .../benchmark_driver.py | 64 ++++++++++++++++++- 1 file changed, 61 insertions(+), 3 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/benchmark_driver.py b/benchmarks/triton_kernels_benchmark/benchmark_driver.py index 2e1ef40fdf..470c6c19e5 100644 --- a/benchmarks/triton_kernels_benchmark/benchmark_driver.py +++ b/benchmarks/triton_kernels_benchmark/benchmark_driver.py @@ -399,19 +399,77 @@ def format_of(ty): return src +def serialize_kernel_metadata(arg, args_dict): + args_dict["num_warps"] = arg.num_warps + args_dict["threads_per_warp"] = arg.threads_per_warp + args_dict["shared_memory"] = arg.shared + args_dict["kernel_name"] = arg.name + args_dict["spv_name"] = f"{arg.name}.spv" + + +def serialize_args(args, constants, signature): + import numbers + dir_path = os.getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS") + if not os.path.exists(dir_path): + os.makedirs(dir_path) + print(f"Path to directory consisting of SPIR-V Runner data: {dir_path}") + + cnt = 0 + args_dict = {"gridX": args[cnt], "gridY": args[cnt + 1], "gridZ": args[cnt + 2]} + args_dict["argument_list"] = [] + counts = {"tensors": 0, "scalars": 0, "karg_cnt": 0} + cnt = 4 + for arg in args[cnt:]: + if type(arg).__name__ == "KernelMetadata": + serialize_kernel_metadata(arg, args_dict) + + if isinstance(arg, torch.Tensor): + cpu_tensor = arg.cpu() + tensor_path = os.path.join(dir_path, f"tensor_{counts['tensors']}.pt") + with open(tensor_path, "wb") as f: + torch.save(cpu_tensor, f) + new_arg = { + "name": f"tensor_{counts['tensors']}", "type": "tensor", "dtype": str(arg.dtype), "ctype": + signature[counts["karg_cnt"]] + } + args_dict["argument_list"].append(new_arg) + counts["karg_cnt"] += 1 + counts["tensors"] += 1 + + if isinstance(arg, numbers.Number): + if counts["karg_cnt"] not in constants: + new_arg = { + "name": f"scalarArg_{counts['scalars']}", "type": "scalar", "value": args[cnt], "ctype": + signature[counts["karg_cnt"]] + } + args_dict["argument_list"].append(new_arg) + counts["karg_cnt"] += 1 + counts["scalars"] += 1 + cnt += 1 + # Dump argument info as a JSON file + json_path = os.path.join(dir_path, "args_data.json") + with open(json_path, "w", encoding="utf-8") as json_file: + import json + json.dump(args_dict, json_file, indent=4) + + class XPULauncher: def __init__(self, src, metadata): # pylint: disable=unused-argument ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} constants = src.constants if hasattr(src, "constants") else {} cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i - constants = {cst_key(key): value for key, value in constants.items()} - signature = {cst_key(key): value for key, value in src.signature.items()} - src = make_launcher(constants, signature, ids) + self.constants = {cst_key(key): value for key, value in constants.items()} + self.signature = {cst_key(key): value for key, value in src.signature.items()} + src = make_launcher(self.constants, self.signature, ids) mod = compile_module_from_src(src, "__triton_launcher") self.launch = mod.launch def __call__(self, *args, **kwargs): + # Serialize KernelArguments for SPIR-V Runner + serialize_kernel_args = os.getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS", None) + if serialize_kernel_args: + serialize_args(args, self.constants, self.signature) self.launch(*args, **kwargs)