Skip to content

Commit 6ce2c4e

Browse files
committed
Update
1 parent 117d510 commit 6ce2c4e

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

graph_net/torch/test_compiler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def load_class_from_file(
6565
spec = importlib.util.spec_from_loader(module_name, loader=None)
6666
module = importlib.util.module_from_spec(spec)
6767
sys.modules[module_name] = module
68-
compiled_code = compile(model_code, filename=file, mode="exec")
68+
compiled_code = compile(cleaned_code, filename=file, mode="exec")
6969
exec(compiled_code, module.__dict__)
7070

7171
model_class = getattr(module, class_name, None)
@@ -88,6 +88,9 @@ def get_model(args):
8888
def get_input_dict(args):
8989
inputs_params = utils.load_converted_from_text(f"{args.model_path}")
9090
params = inputs_params["weight_info"]
91+
for tensor_meta in params.values():
92+
if hasattr(tensor_meta, "device"):
93+
tensor_meta.device = args.device
9194
return {
9295
k: utils.replay_tensor(v).to(torch.device(args.device))
9396
for k, v in params.items()

0 commit comments

Comments
 (0)