1818import  inspect 
1919import  itertools 
2020import  json 
21+ import  time 
2122import  os 
2223import  re 
2324from  collections  import  OrderedDict 
@@ -538,6 +539,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
538539        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. 
539540        ``` 
540541        """ 
542+         init  =  time .time ()
541543        cache_dir  =  kwargs .pop ("cache_dir" , None )
542544        ignore_mismatched_sizes  =  kwargs .pop ("ignore_mismatched_sizes" , False )
543545        force_download  =  kwargs .pop ("force_download" , False )
@@ -637,6 +639,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
637639        }
638640
639641        # load config 
642+         start  =  time .time ()
640643        config , unused_kwargs , commit_hash  =  cls .load_config (
641644            config_path ,
642645            cache_dir = cache_dir ,
@@ -651,6 +654,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
651654            user_agent = user_agent ,
652655            ** kwargs ,
653656        )
657+         print(time.time() - start, "config load")
654658        # no in-place modification of the original config. 
655659        config  =  copy .deepcopy (config )
656660
@@ -816,6 +820,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
816820                )
817821
818822            if  low_cpu_mem_usage :
823+                 print ("low cpu mem use" )
819824                # Instantiate model with empty weights 
820825                with  accelerate .init_empty_weights ():
821826                    model  =  cls .from_config (config , ** unused_kwargs )
@@ -829,6 +834,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
829834                if  device_map  is  None  and  not  is_sharded :
830835                    # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None. 
831836                    # It would error out during the `validate_environment()` call above in the absence of cuda. 
837+                     
832838                    is_quant_method_bnb  =  (
833839                        getattr (model , "quantization_method" , None ) ==  QuantizationMethod .BITS_AND_BYTES 
834840                    )
@@ -874,6 +880,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
874880                else :  # else let accelerate handle loading and dispatching. 
875881                    # Load weights and dispatch according to the device_map 
876882                    # by default the device_map is None and the weights are loaded on the CPU 
883+                     print ("accelerate load checkpoint" )
877884                    force_hook  =  True 
878885                    device_map  =  _determine_device_map (
879886                        model , device_map , max_memory , torch_dtype , keep_in_fp32_modules , hf_quantizer 
@@ -935,11 +942,23 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
935942                    "error_msgs" : [],
936943                }
937944            else :
945+                 print ("did from_config" )
946+                 print (cls )
947+                 torch.cuda.synchronize("cuda")
948+                 start  =  time .time ()
938949                model  =  cls .from_config (config , ** unused_kwargs )
950+                 torch .cuda .synchronize ("cuda" )
951+                 print (time .time ()- start ,"from_config" )
939952
953+                 torch.cuda.synchronize("cuda")
954+                 start  =  time .time ()
940955                state_dict  =  load_state_dict (model_file , variant = variant )
956+                 torch .cuda .synchronize ("cuda" )
957+                 print (time .time ()- start ,"load_state_dict" )
941958                model ._convert_deprecated_attention_blocks (state_dict )
942959
960+                 torch.cuda.synchronize("cuda")
961+                 start  =  time .time ()
943962                model , missing_keys , unexpected_keys , mismatched_keys , error_msgs  =  cls ._load_pretrained_model (
944963                    model ,
945964                    state_dict ,
@@ -948,13 +967,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
948967                    ignore_mismatched_sizes = ignore_mismatched_sizes ,
949968                )
950969
970+                 torch .cuda .synchronize ("cuda" )
971+                 print(time.time() - start, "load_pretrained_model")
972+ 
951973                loading_info  =  {
952974                    "missing_keys" : missing_keys ,
953975                    "unexpected_keys" : unexpected_keys ,
954976                    "mismatched_keys" : mismatched_keys ,
955977                    "error_msgs" : error_msgs ,
956978                }
957979
980+         torch .cuda .synchronize ("cuda" )
981+         start =  time .time ()
958982        if  hf_quantizer  is  not   None :
959983            hf_quantizer .postprocess_model (model )
960984            model .hf_quantizer  =  hf_quantizer 
@@ -975,11 +999,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
975999        else :
9761000            model .register_to_config (_name_or_path = pretrained_model_name_or_path )
9771001
1002+         torch .cuda .synchronize ("cuda" )
1003+         print (time .time ()- start ,"to device" )
1004+ 
9781005        # Set model in evaluation mode to deactivate DropOut modules by default 
9791006        model .eval ()
9801007        if  output_loading_info :
9811008            return  model , loading_info 
9821009
1010+         print (time .time ()- init ,"total" )
9831011        return  model 
9841012
9851013    # Adapted from `transformers`. 
0 commit comments