Skip to content

Commit 61fd54d

Browse files
authored
Dump SPIRV kernel in benchmark_driver.py (#2535)
Closes #2536 Mostly copying from #2258 with minor code style adjustments. Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 5529b87 commit 61fd54d

File tree

1 file changed

+61
-3
lines changed

1 file changed

+61
-3
lines changed

benchmarks/triton_kernels_benchmark/benchmark_driver.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -399,19 +399,77 @@ def format_of(ty):
399399
return src
400400

401401

402+
def serialize_kernel_metadata(arg, args_dict):
403+
args_dict["num_warps"] = arg.num_warps
404+
args_dict["threads_per_warp"] = arg.threads_per_warp
405+
args_dict["shared_memory"] = arg.shared
406+
args_dict["kernel_name"] = arg.name
407+
args_dict["spv_name"] = f"{arg.name}.spv"
408+
409+
410+
def serialize_args(args, constants, signature):
411+
import numbers
412+
dir_path = os.getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS")
413+
if not os.path.exists(dir_path):
414+
os.makedirs(dir_path)
415+
print(f"Path to directory consisting of SPIR-V Runner data: {dir_path}")
416+
417+
cnt = 0
418+
args_dict = {"gridX": args[cnt], "gridY": args[cnt + 1], "gridZ": args[cnt + 2]}
419+
args_dict["argument_list"] = []
420+
counts = {"tensors": 0, "scalars": 0, "karg_cnt": 0}
421+
cnt = 4
422+
for arg in args[cnt:]:
423+
if type(arg).__name__ == "KernelMetadata":
424+
serialize_kernel_metadata(arg, args_dict)
425+
426+
if isinstance(arg, torch.Tensor):
427+
cpu_tensor = arg.cpu()
428+
tensor_path = os.path.join(dir_path, f"tensor_{counts['tensors']}.pt")
429+
with open(tensor_path, "wb") as f:
430+
torch.save(cpu_tensor, f)
431+
new_arg = {
432+
"name": f"tensor_{counts['tensors']}", "type": "tensor", "dtype": str(arg.dtype), "ctype":
433+
signature[counts["karg_cnt"]]
434+
}
435+
args_dict["argument_list"].append(new_arg)
436+
counts["karg_cnt"] += 1
437+
counts["tensors"] += 1
438+
439+
if isinstance(arg, numbers.Number):
440+
if counts["karg_cnt"] not in constants:
441+
new_arg = {
442+
"name": f"scalarArg_{counts['scalars']}", "type": "scalar", "value": args[cnt], "ctype":
443+
signature[counts["karg_cnt"]]
444+
}
445+
args_dict["argument_list"].append(new_arg)
446+
counts["karg_cnt"] += 1
447+
counts["scalars"] += 1
448+
cnt += 1
449+
# Dump argument info as a JSON file
450+
json_path = os.path.join(dir_path, "args_data.json")
451+
with open(json_path, "w", encoding="utf-8") as json_file:
452+
import json
453+
json.dump(args_dict, json_file, indent=4)
454+
455+
402456
class XPULauncher:
403457

404458
def __init__(self, src, metadata): # pylint: disable=unused-argument
405459
ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
406460
constants = src.constants if hasattr(src, "constants") else {}
407461
cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
408-
constants = {cst_key(key): value for key, value in constants.items()}
409-
signature = {cst_key(key): value for key, value in src.signature.items()}
410-
src = make_launcher(constants, signature, ids)
462+
self.constants = {cst_key(key): value for key, value in constants.items()}
463+
self.signature = {cst_key(key): value for key, value in src.signature.items()}
464+
src = make_launcher(self.constants, self.signature, ids)
411465
mod = compile_module_from_src(src, "__triton_launcher")
412466
self.launch = mod.launch
413467

414468
def __call__(self, *args, **kwargs):
469+
# Serialize KernelArguments for SPIR-V Runner
470+
serialize_kernel_args = os.getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS", None)
471+
if serialize_kernel_args:
472+
serialize_args(args, self.constants, self.signature)
415473
self.launch(*args, **kwargs)
416474

417475

0 commit comments

Comments
 (0)