Skip to content

Commit 2213528

Browse files
authored
Fix Key Error in kernel args serialization (#3182)
Enable `TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS` may encounter python key error. The reason is that the way of accessing dict `signature` is out-of-date. This PR just updates `signature` accessing.
1 parent 686a8c1 commit 2213528

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

third_party/intel/backend/driver.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -548,36 +548,40 @@ def serialize_args(args, constants, signature):
548548

549549
cnt = 0
550550
args_dict = {"gridX": args[cnt], "gridY": args[cnt + 1], "gridZ": args[cnt + 2]}
551+
# 3: stream
552+
# 4: function
553+
# 5: packed kernel metadata
554+
assert type(args[cnt + 5]).__name__ == "KernelMetadata"
555+
serialize_kernel_metadata(args[cnt + 5], args_dict)
556+
# 6: launch_metadata
557+
# 7: launch_enter_hook
558+
# 8: launch_exit_hook
551559
args_dict['argument_list'] = []
552560
counts = {"tensors": 0, "scalars": 0, "karg_cnt": 0}
553-
cnt = 4
561+
cnt += 9
554562
for arg in args[cnt:]:
555-
if type(arg).__name__ == "KernelMetadata":
556-
serialize_kernel_metadata(arg, args_dict)
557-
563+
sig_name = list(signature.keys())[counts['karg_cnt']]
558564
if isinstance(arg, torch.Tensor):
559565
cpu_tensor = arg.cpu()
560566
tensor_path = os.path.join(dir_path, f"tensor_{counts['tensors']}.pt")
561567
with open(tensor_path, 'wb') as f:
562568
torch.save(cpu_tensor, f)
563569
new_arg = {
564570
"name": f"tensor_{counts['tensors']}", "type": "tensor", "dtype": str(arg.dtype), "ctype":
565-
signature[counts['karg_cnt']]
571+
signature[sig_name]
566572
}
567573
args_dict['argument_list'].append(new_arg)
568-
counts['karg_cnt'] += 1
569574
counts['tensors'] += 1
570-
571575
if isinstance(arg, numbers.Number):
572-
if counts['karg_cnt'] not in constants:
576+
if (counts['karg_cnt'], ) not in constants.keys():
573577
new_arg = {
574-
"name": f"scalarArg_{counts['scalars']}", "type": "scalar", "value": args[cnt], "ctype":
575-
signature[counts['karg_cnt']]
578+
"name": f"scalarArg_{counts['scalars']}", "type": "scalar", "value": arg, "ctype":
579+
signature[sig_name]
576580
}
577581
args_dict['argument_list'].append(new_arg)
578-
counts['karg_cnt'] += 1
579582
counts['scalars'] += 1
580-
cnt += 1
583+
counts['karg_cnt'] += 1
584+
581585
# Dump argument info as a JSON file
582586
json_path = os.path.join(dir_path, 'args_data.json')
583587
with open(json_path, 'w') as json_file:

utils/SPIRVRunner/SPIRVRunner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ void set_argument(sycl::handler &cgh, int index, ordered_json &item) {
250250
} else if (type == "u64") {
251251
auto val = item.at("value").get<uint64_t>();
252252
set_scalar_arg<uint64_t>(cgh, index, &val);
253-
} else if (type == "fp32" || type == "fp32" || type == "f32") {
253+
} else if (type == "fp32" || type == "f32") {
254254
auto val = item.at("value").get<float>();
255255
set_scalar_arg<float>(cgh, index, &val);
256256
} else if (type == "fp64") {

0 commit comments

Comments
 (0)