@@ -69,15 +69,11 @@ def load_class_from_file(
6969 module_name = file .stem
7070
7171 with open (file_path , "r" , encoding = "utf-8" ) as f :
72- original_code = f .read ()
73- if args .device == "cuda" :
74- modified_code = original_code .replace ("cpu" , "cuda" )
75- else :
76- modified_code = original_code
72+ model_code = f .read ()
7773 spec = importlib .util .spec_from_loader (module_name , loader = None )
7874 module = importlib .util .module_from_spec (spec )
7975 sys .modules [module_name ] = module
80- compiled_code = compile (modified_code , filename = file , mode = "exec" )
76+ compiled_code = compile (model_code , filename = file , mode = "exec" )
8177 exec (compiled_code , module .__dict__ )
8278
8379 model_class = getattr (module , class_name , None )
@@ -91,7 +87,10 @@ def get_compiler_backend(args) -> GraphCompilerBackend:
9187
9288def get_model (args ):
9389 model_class = load_class_from_file (args , class_name = "GraphModule" )
94- return model_class ().to (torch .device (args .device ))
90+ model = model_class ().to (torch .device (args .device ))
91+ # for param in model.parameters():
92+ # param.requires_grad_(False)
93+ return model
9594
9695
9796def get_input_dict (args ):
0 commit comments