@@ -94,6 +94,12 @@ def __init__(self,
9494                self .initialize_validation_pipeline (self .training_args )
9595            self .initialize_training_pipeline (self .training_args )
9696
97+         if  fastvideo_args .distill_mode :
98+             self .initialize_distillation_pipeline (fastvideo_args )
99+ 
100+         if  fastvideo_args .log_validation :
101+             self .initialize_validation_pipeline (fastvideo_args )
102+ 
97103        self .initialize_pipeline (fastvideo_args )
98104
99105        if  not  fastvideo_args .training_mode :
@@ -109,6 +115,10 @@ def initialize_validation_pipeline(self, training_args: TrainingArgs):
109115            "if log_validation is True, the pipeline must implement this method" 
110116        )
111117
118+     def  initialize_distillation_pipeline (self , fastvideo_args : FastVideoArgs ):
119+         raise  NotImplementedError (
120+             "if distill_mode is True, the pipeline must implement this method" )
121+ 
112122    @classmethod  
113123    def  from_pretrained (cls ,
114124                        model_path : str ,
@@ -148,7 +158,7 @@ def from_pretrained(cls,
148158            config_args  =  shallow_asdict (config )
149159            config_args .update (kwargs )
150160
151-         if  args   is   None   or   args . inference_mode :
161+         if  args . mode   ==   "inference" :
152162            fastvideo_args  =  FastVideoArgs (model_path = model_path ,
153163                                           device_str = device  or  "cuda"  if 
154164                                           torch .cuda .is_available () else  "cpu" ,
@@ -172,7 +182,7 @@ def from_pretrained(cls,
172182            fastvideo_args .num_gpus  =  int (os .environ .get ("WORLD_SIZE" , 1 ))
173183            fastvideo_args .use_cpu_offload  =  False 
174184            # make sure we are in training mode 
175-             fastvideo_args .inference_mode  =  False 
185+             fastvideo_args .mode  =  args . mode 
176186            # we hijack the precision to be the master weight type so that the 
177187            # model is loaded with the correct precision. Subsequently we will 
178188            # use FSDP2's MixedPrecisionPolicy to set the precision for the 
0 commit comments