Skip to content

Commit e92d660

Browse files
committed
Update
1 parent 45c8e5a commit e92d660

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

graph_net/torch/test_compiler.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,11 @@ def load_class_from_file(
6969
module_name = file.stem
7070

7171
with open(file_path, "r", encoding="utf-8") as f:
72-
original_code = f.read()
73-
if args.device == "cuda":
74-
modified_code = original_code.replace("cpu", "cuda")
75-
else:
76-
modified_code = original_code
72+
model_code = f.read()
7773
spec = importlib.util.spec_from_loader(module_name, loader=None)
7874
module = importlib.util.module_from_spec(spec)
7975
sys.modules[module_name] = module
80-
compiled_code = compile(modified_code, filename=file, mode="exec")
76+
compiled_code = compile(model_code, filename=file, mode="exec")
8177
exec(compiled_code, module.__dict__)
8278

8379
model_class = getattr(module, class_name, None)
@@ -91,7 +87,10 @@ def get_compiler_backend(args) -> GraphCompilerBackend:
9187

9288
def get_model(args):
9389
model_class = load_class_from_file(args, class_name="GraphModule")
94-
return model_class().to(torch.device(args.device))
90+
model = model_class().to(torch.device(args.device))
91+
# for param in model.parameters():
92+
# param.requires_grad_(False)
93+
return model
9594

9695

9796
def get_input_dict(args):

0 commit comments

Comments
 (0)