@@ -122,52 +122,6 @@ def prepare_dataset(self) -> None:
122122 pin_memory = self .args .pin_memory ,
123123 )
124124
125- def _get_load_components_kwargs (self ) -> Dict [str , Any ]:
126- load_component_kwargs = {
127- "text_encoder_dtype" : self .args .text_encoder_dtype ,
128- "text_encoder_2_dtype" : self .args .text_encoder_2_dtype ,
129- "text_encoder_3_dtype" : self .args .text_encoder_3_dtype ,
130- "transformer_dtype" : self .args .transformer_dtype ,
131- "vae_dtype" : self .args .vae_dtype ,
132- "shift" : self .args .flow_shift ,
133- "revision" : self .args .revision ,
134- "cache_dir" : self .args .cache_dir ,
135- }
136- if self .args .pretrained_model_name_or_path is not None :
137- load_component_kwargs ["model_id" ] = self .args .pretrained_model_name_or_path
138- return load_component_kwargs
139-
140- def _set_components (self , components : Dict [str , Any ]) -> None :
141- # Set models
142- self .tokenizer = components .get ("tokenizer" , self .tokenizer )
143- self .tokenizer_2 = components .get ("tokenizer_2" , self .tokenizer_2 )
144- self .tokenizer_3 = components .get ("tokenizer_3" , self .tokenizer_3 )
145- self .text_encoder = components .get ("text_encoder" , self .text_encoder )
146- self .text_encoder_2 = components .get ("text_encoder_2" , self .text_encoder_2 )
147- self .text_encoder_3 = components .get ("text_encoder_3" , self .text_encoder_3 )
148- self .transformer = components .get ("transformer" , self .transformer )
149- self .unet = components .get ("unet" , self .unet )
150- self .vae = components .get ("vae" , self .vae )
151- self .scheduler = components .get ("scheduler" , self .scheduler )
152-
153- # Set configs
154- self .transformer_config = self .transformer .config if self .transformer is not None else self .transformer_config
155- self .vae_config = self .vae .config if self .vae is not None else self .vae_config
156-
def _delete_components(self) -> None:
    """Drop references to every model component and reclaim memory.

    After clearing the attributes, ``free_memory`` is invoked and the
    accelerator's CUDA device is synchronized so pending work completes
    before training continues.
    """
    for name in (
        "tokenizer",
        "tokenizer_2",
        "tokenizer_3",
        "text_encoder",
        "text_encoder_2",
        "text_encoder_3",
        "transformer",
        "unet",
        "vae",
        "scheduler",
    ):
        setattr(self, name, None)
    free_memory()
    torch.cuda.synchronize(self.state.accelerator.device)
171125 def prepare_models (self ) -> None :
172126 logger .info ("Initializing models" )
173127
@@ -1109,6 +1063,52 @@ def _move_components_to_device(self):
11091063 if self .vae is not None :
11101064 self .vae = self .vae .to (self .state .accelerator .device )
11111065
1066+ def _get_load_components_kwargs (self ) -> Dict [str , Any ]:
1067+ load_component_kwargs = {
1068+ "text_encoder_dtype" : self .args .text_encoder_dtype ,
1069+ "text_encoder_2_dtype" : self .args .text_encoder_2_dtype ,
1070+ "text_encoder_3_dtype" : self .args .text_encoder_3_dtype ,
1071+ "transformer_dtype" : self .args .transformer_dtype ,
1072+ "vae_dtype" : self .args .vae_dtype ,
1073+ "shift" : self .args .flow_shift ,
1074+ "revision" : self .args .revision ,
1075+ "cache_dir" : self .args .cache_dir ,
1076+ }
1077+ if self .args .pretrained_model_name_or_path is not None :
1078+ load_component_kwargs ["model_id" ] = self .args .pretrained_model_name_or_path
1079+ return load_component_kwargs
1080+
1081+ def _set_components (self , components : Dict [str , Any ]) -> None :
1082+ # Set models
1083+ self .tokenizer = components .get ("tokenizer" , self .tokenizer )
1084+ self .tokenizer_2 = components .get ("tokenizer_2" , self .tokenizer_2 )
1085+ self .tokenizer_3 = components .get ("tokenizer_3" , self .tokenizer_3 )
1086+ self .text_encoder = components .get ("text_encoder" , self .text_encoder )
1087+ self .text_encoder_2 = components .get ("text_encoder_2" , self .text_encoder_2 )
1088+ self .text_encoder_3 = components .get ("text_encoder_3" , self .text_encoder_3 )
1089+ self .transformer = components .get ("transformer" , self .transformer )
1090+ self .unet = components .get ("unet" , self .unet )
1091+ self .vae = components .get ("vae" , self .vae )
1092+ self .scheduler = components .get ("scheduler" , self .scheduler )
1093+
1094+ # Set configs
1095+ self .transformer_config = self .transformer .config if self .transformer is not None else self .transformer_config
1096+ self .vae_config = self .vae .config if self .vae is not None else self .vae_config
1097+
def _delete_components(self) -> None:
    """Release all model components and reclaim accelerator memory.

    Clears every component attribute, calls ``free_memory`` to trigger
    garbage collection / cache emptying, and synchronizes the accelerator's
    CUDA device so deferred frees actually complete before training
    proceeds.
    """
    for name in (
        "tokenizer",
        "tokenizer_2",
        "tokenizer_3",
        "text_encoder",
        "text_encoder_2",
        "text_encoder_3",
        "transformer",
        "unet",
        "vae",
        "scheduler",
    ):
        setattr(self, name, None)
    free_memory()
    # Fix: synchronizing unconditionally raises on CPU-only hosts; only
    # synchronize when CUDA is actually available.
    if torch.cuda.is_available():
        torch.cuda.synchronize(self.state.accelerator.device)
11121112 def _get_training_dtype (self , accelerator ) -> torch .dtype :
11131113 weight_dtype = torch .float32
11141114 if accelerator .state .deepspeed_plugin :
0 commit comments