File tree Expand file tree Collapse file tree 1 file changed +4
-1
lines changed
Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -65,7 +65,7 @@ def load_class_from_file(
6565 spec = importlib .util .spec_from_loader (module_name , loader = None )
6666 module = importlib .util .module_from_spec (spec )
6767 sys .modules [module_name ] = module
68- compiled_code = compile (model_code , filename = file , mode = "exec" )
68+ compiled_code = compile (cleaned_code , filename = file , mode = "exec" )
6969 exec (compiled_code , module .__dict__ )
7070
7171 model_class = getattr (module , class_name , None )
@@ -88,6 +88,9 @@ def get_model(args):
8888def get_input_dict (args ):
8989 inputs_params = utils .load_converted_from_text (f"{ args .model_path } " )
9090 params = inputs_params ["weight_info" ]
91+ for tensor_meta in params .values ():
92+ if hasattr (tensor_meta , "device" ):
93+ tensor_meta .device = args .device
9194 return {
9295 k : utils .replay_tensor (v ).to (torch .device (args .device ))
9396 for k , v in params .items ()
You can’t perform that action at this time.
0 commit comments