1010from  typing  import  Any , Generator , Iterable , List , Optional , Tuple , cast 
1111
1212import  torch 
13- import  torch .distributed  as  dist 
1413import  torch .nn  as  nn 
1514from  safetensors .torch  import  load_file  as  safetensors_load_file 
1615from  transformers  import  AutoImageProcessor , AutoTokenizer 
2120from  fastvideo .v1 .fastvideo_args  import  FastVideoArgs 
2221from  fastvideo .v1 .logger  import  init_logger 
2322from  fastvideo .v1 .models .hf_transformer_utils  import  get_diffusers_config 
24- from  fastvideo .v1 .models .loader .fsdp_load  import  (init_device_mesh ,
25-                                                   maybe_load_fsdp_model ,
26-                                                   shard_model )
23+ from  fastvideo .v1 .models .loader .fsdp_load  import  maybe_load_fsdp_model 
2724from  fastvideo .v1 .models .loader .utils  import  set_default_torch_dtype 
2825from  fastvideo .v1 .models .loader .weight_utils  import  (
2926    filter_duplicate_safetensors_files , filter_files_not_needed_for_inference ,
@@ -166,19 +163,16 @@ def _prepare_weights(
166163        return  hf_folder , hf_weights_files , use_safetensors 
167164
168165    def  _get_weights_iterator (
169-             self ,
170-             source : "Source" ,
171-             to_cpu : bool  =  True 
166+             self , source : "Source" 
172167    ) ->  Generator [Tuple [str , torch .Tensor ], None , None ]:
173168        """Get an iterator for the model weights based on the load format.""" 
174169        hf_folder , hf_weights_files , use_safetensors  =  self ._prepare_weights (
175170            source .model_or_path , source .fall_back_to_pt ,
176171            source .allow_patterns_overrides )
177172        if  use_safetensors :
178-             weights_iterator  =  safetensors_weights_iterator (
179-                 hf_weights_files , to_cpu )
173+             weights_iterator  =  safetensors_weights_iterator (hf_weights_files )
180174        else :
181-             weights_iterator  =  pt_weights_iterator (hf_weights_files ,  to_cpu )
175+             weights_iterator  =  pt_weights_iterator (hf_weights_files )
182176
183177        if  self .counter_before_loading_weights  ==  0.0 :
184178            self .counter_before_loading_weights  =  time .perf_counter ()
@@ -187,11 +181,10 @@ def _get_weights_iterator(
187181                for  (name , tensor ) in  weights_iterator )
188182
189183    def  _get_all_weights (
190-             self ,
191-             model_config : Any ,
192-             model : nn .Module ,
193-             model_path : str ,
194-             to_cpu : bool  =  True 
184+         self ,
185+         model_config : Any ,
186+         model : nn .Module ,
187+         model_path : str ,
195188    ) ->  Generator [Tuple [str , torch .Tensor ], None , None ]:
196189        primary_weights  =  TextEncoderLoader .Source (
197190            model_path ,
@@ -200,14 +193,14 @@ def _get_all_weights(
200193            allow_patterns_overrides = getattr (model , "allow_patterns_overrides" ,
201194                                             None ),
202195        )
203-         yield  from  self ._get_weights_iterator (primary_weights ,  to_cpu )
196+         yield  from  self ._get_weights_iterator (primary_weights )
204197
205198        secondary_weights  =  cast (
206199            Iterable [TextEncoderLoader .Source ],
207200            getattr (model , "secondary_weights" , ()),
208201        )
209202        for  source  in  secondary_weights :
210-             yield  from  self ._get_weights_iterator (source ,  to_cpu )
203+             yield  from  self ._get_weights_iterator (source )
211204
212205    def  load (self , model_path : str , architecture : str ,
213206             fastvideo_args : FastVideoArgs ):
@@ -243,19 +236,13 @@ def load(self, model_path: str, architecture: str,
243236        target_device  =  get_local_torch_device ()
244237        # TODO(will): add support for other dtypes 
245238        return  self .load_model (model_path , encoder_config , target_device ,
246-                                fastvideo_args ,  encoder_precision )
239+                                encoder_precision )
247240
248241    def  load_model (self ,
249242                   model_path : str ,
250243                   model_config : EncoderConfig ,
251244                   target_device : torch .device ,
252-                    fastvideo_args : FastVideoArgs ,
253245                   dtype : str  =  "fp16" ):
254-         use_cpu_offload  =  fastvideo_args .text_encoder_offload  and  len (
255-             getattr (model_config , "_fsdp_shard_conditions" , [])) >  0 
256- 
257-         if  fastvideo_args .text_encoder_offload :
258-             target_device  =  torch .device ("cpu" )
259246        with  set_default_torch_dtype (PRECISION_TO_TYPE [dtype ]):
260247            with  target_device :
261248                architectures  =  getattr (model_config , "architectures" , [])
@@ -264,26 +251,12 @@ def load_model(self,
264251
265252            weights_to_load  =  {name  for  name , _  in  model .named_parameters ()}
266253            loaded_weights  =  model .load_weights (
267-                 self ._get_all_weights (model_config , model , model_path ,
268-                                       use_cpu_offload ))
254+                 self ._get_all_weights (model_config , model , model_path ))
269255            self .counter_after_loading_weights  =  time .perf_counter ()
270256            logger .info (
271257                "Loading weights took %.2f seconds" ,
272258                self .counter_after_loading_weights  - 
273259                self .counter_before_loading_weights )
274- 
275-             if  use_cpu_offload :
276-                 mesh  =  init_device_mesh (
277-                     "cuda" ,
278-                     mesh_shape = (1 , dist .get_world_size ()),
279-                     mesh_dim_names = ("offload" , "replicate" ),
280-                 )
281-                 shard_model (model ,
282-                             cpu_offload = True ,
283-                             reshard_after_forward = True ,
284-                             mesh = mesh ["offload" ],
285-                             fsdp_shard_conditions = model ._fsdp_shard_conditions ,
286-                             pin_cpu_memory = fastvideo_args .pin_cpu_memory )
287260            # We only enable strict check for non-quantized models 
288261            # that have loaded weights tracking currently. 
289262            # if loaded_weights is not None: 
@@ -320,7 +293,7 @@ def load(self, model_path: str, architecture: str,
320293        target_device  =  get_local_torch_device ()
321294        # TODO(will): add support for other dtypes 
322295        return  self .load_model (
323-             model_path , encoder_config , target_device ,  fastvideo_args , 
296+             model_path , encoder_config , target_device ,
324297            fastvideo_args .pipeline_config .image_encoder_precision )
325298
326299
@@ -567,4 +540,4 @@ def load_module(module_name: str, component_model_path: str,
567540                                                 transformers_or_diffusers )
568541
569542        # Load the module 
570-         return  loader .load (component_model_path , architecture , fastvideo_args )
543+         return  loader .load (component_model_path , architecture , fastvideo_args )
0 commit comments