Skip to content

Commit 8910746

Browse files
authored
workaroud for prebuild error (ROCm#1588)
* workaroud for prebuild error * update * fix lint
1 parent 5cc97aa commit 8910746

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

csrc/ck_deepgemm/gen_instances.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self, working_path, istune=False):
1919
# self.b_dtype = b_dtype.upper()
2020
# self.c_dtype = c_dtype.upper()
2121
# self.quant_type = quant_type
22-
assert (istune == False, "not surpport tuning!")
22+
assert istune == False, "not surpport tuning!"
2323

2424
def gen_instance(self, k: kernelInstance):
2525
INSTANCE_IMPL = f"""// SPDX-License-Identifier: MIT

csrc/ck_gemm_a8w8_bpreshuffle/gen_instances.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,14 +236,16 @@ def get_tune_dict(tune_dict_csv):
236236
gpu = torch.cuda.current_device()
237237
device_properties = torch.cuda.get_device_properties(gpu)
238238
cu_num = device_properties.multi_processor_count
239-
tune_df = tune_df[
240-
(tune_df["cu_num"] == cu_num) & (tune_df["libtype"] == "ck")
241-
].reset_index()
239+
tune_df = tune_df[(tune_df["cu_num"] == cu_num)].reset_index()
240+
tune_df = tune_df[tune_df["libtype"] == "ck"].reset_index()
242241
for i in range(len(tune_df)):
243242
M = tune_df.loc[i, "M"]
244243
N = tune_df.loc[i, "N"]
245244
K = tune_df.loc[i, "K"]
246245
kid = tune_df.loc[i, "kernelId"]
246+
if kid < 0 or kid >= len(kernels_list):
247+
print(f"[Warning]: kernelId {kid} is out of range, skip it")
248+
continue
247249
tune_dict[(M, N, K)] = kernels_list[kid]
248250
return tune_dict
249251

csrc/cktile_gemm_a8w8_bpreshuffle/gen_instances.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,14 +233,16 @@ def get_tune_dict(tune_dict_csv):
233233
gpu = torch.cuda.current_device()
234234
device_properties = torch.cuda.get_device_properties(gpu)
235235
cu_num = device_properties.multi_processor_count
236-
tune_df = tune_df[
237-
(tune_df["cu_num"] == cu_num) & (tune_df["libtype"] == "cktile")
238-
].reset_index()
236+
tune_df = tune_df[(tune_df["cu_num"] == cu_num)].reset_index()
237+
tune_df = tune_df[tune_df["libtype"] == "cktile"].reset_index()
239238
for i in range(len(tune_df)):
240239
M = tune_df.loc[i, "M"]
241240
N = tune_df.loc[i, "N"]
242241
K = tune_df.loc[i, "K"]
243242
kid = tune_df.loc[i, "kernelId"]
243+
if kid < 0 or kid > len(kernels_list):
244+
print(f"[Warning]: kernelId {kid} is out of range, skip it")
245+
continue
244246
tune_dict[(M, N, K)] = kernels_list[kid]
245247
return tune_dict
246248

0 commit comments

Comments
 (0)