Skip to content

Commit 9363023

Browse files
committed
add failure protection and log output when removing directories
1 parent de54e88 commit 9363023

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

graph_net/torch/post_extract_process_count_kernels.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -25,14 +25,29 @@ def __call__(self, model_path=None):
2525
params = inputs_params["weight_info"]
2626
state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
2727

28-
model(**state_dict)
29-
compiled_model = torch.compile(model)
28+
# try to run the model
29+
try:
30+
model(**state_dict)
31+
except Exception as e:
32+
print(f"failed in running model:{e}")
33+
print(f"removing: {model_path}")
34+
shutil.rmtree(model_path)
35+
return False
36+
# try to compile the model
37+
try:
38+
compiled_model = torch.compile(model)
39+
except Exception as e:
40+
print(f"failed in compiling model:{e}")
41+
print(f"removing: {model_path}")
42+
shutil.rmtree(model_path)
43+
return False
3044
compiled_num_of_kernels = count_kernels(compiled_model, state_dict)
3145
if compiled_num_of_kernels == 1:
3246
print(model_path, "can be fully integrated")
3347
return True
3448
else:
35-
print(model_path, "can not be fully integrated")
49+
print(f"{model_path} can not be fully integrated, to be removed")
50+
print(f"removing: {model_path}")
3651
shutil.rmtree(model_path)
3752
return False
3853

0 commit comments

Comments
 (0)