@@ -152,7 +152,6 @@ def start_training(self):
         for step in range(0, self.args.max_train_steps):
             print("step: ", step)
             batch = next(self.dataloader)
-            breakpoint()
             if step == measure_start_step:
                 if PROFILE_DIR is not None:
                     xm.wait_device_ops()
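For context, the gate above follows the usual XLA warm-up-then-trace pattern: drain queued device work before capturing, so warm-up compilations don't pollute the trace. A minimal sketch; the profiler port, duration, and helper name are illustrative assumptions, not code from this PR:

import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

server = xp.start_server(9012)  # profiler endpoint; the port is arbitrary

def maybe_start_profile(step, measure_start_step, profile_dir):
    # Block until all pending XLA ops finish, then capture a detached
    # trace of the steady-state steps into profile_dir.
    if step == measure_start_step and profile_dir is not None:
        xm.wait_device_ops()
        xp.trace_detached("localhost:9012", profile_dir, duration_ms=30000)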
@@ -164,22 +163,22 @@ def start_training(self):
             def print_loss_closure(step, loss):
                 print(f"Step: {step}, Loss: {loss}")
 
-            # if self.args.print_loss:
-            #     xm.add_step_closure(
-            #         print_loss_closure,
-            #         args=(
-            #             self.global_step,
-            #             loss,
-            #         ),
-            #     )
-            # xm.mark_step()
-            # if not dataloader_exception:
-            #     xm.wait_device_ops()
-            #     total_time = time.time() - last_time
-            #     print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
-            # else:
-            #     print("dataloader exception happen, skip result")
-            # return
+            if self.args.print_loss:
+                xm.add_step_closure(
+                    print_loss_closure,
+                    args=(
+                        self.global_step,
+                        loss,
+                    ),
+                )
+            xm.mark_step()
+        if not dataloader_exception:
+            xm.wait_device_ops()
+            total_time = time.time() - last_time
+            print(f"Average step time: {total_time / (self.args.max_train_steps - measure_start_step)}")
+        else:
+            print("dataloader exception happened, skipping result")
+        return
     def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
         sigmas = self.noise_scheduler_copy.sigmas.to(device=self.device, dtype=dtype)
         schedule_timesteps = self.noise_scheduler_copy.timesteps.to(self.device)
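The re-enabled block leans on two torch_xla behaviors worth spelling out: step closures defer host-side work until the step's results materialize, and wall-clock timing is only meaningful after the asynchronous device queue drains. A standalone sketch; global_step, loss, last_time, and the step counts are assumed placeholders from the surrounding loop:

import time
import torch_xla.core.xla_model as xm

def print_loss_closure(step, loss):
    print(f"Step: {step}, Loss: {loss}")

# Runs once the step's tensors materialize, avoiding an eager
# device-to-host transfer of `loss` in the middle of the step.
xm.add_step_closure(print_loss_closure, args=(global_step, loss))
xm.mark_step()  # cut and dispatch the lazily built XLA graph

# XLA executes asynchronously, so flush pending work before timing.
xm.wait_device_ops()
avg = (time.time() - last_time) / (max_train_steps - measure_start_step)
print(f"Average step time: {avg}")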
@@ -307,6 +306,24 @@ def parse_args():
         choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
         help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
     )
+    parser.add_argument(
+        "--logit_mean", type=float, default=0.0, help="Mean to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--logit_std", type=float, default=1.0, help="Std to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--mode_scale",
+        type=float,
+        default=1.29,
+        help="Scale of the mode weighting scheme. Only effective when using `'mode'` as the `weighting_scheme`.",
+    )
+    parser.add_argument(
+        "--guidance_scale",
+        type=float,
+        default=3.5,
+        help="Guidance scale to embed during training; the FLUX.1 dev variant is a guidance-distilled model.",
+    )
     parser.add_argument(
         "--revision",
         type=str,
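Downstream, the first three new flags typically parameterize density-based timestep sampling, as in the upstream diffusers SD3/Flux training scripts. A sketch assuming a diffusers version that ships compute_density_for_timestep_sampling; bsz and the scheduler handle are placeholders:

from diffusers.training_utils import compute_density_for_timestep_sampling

u = compute_density_for_timestep_sampling(
    weighting_scheme=args.weighting_scheme,  # e.g. "logit_normal" or "mode"
    batch_size=bsz,
    logit_mean=args.logit_mean,
    logit_std=args.logit_std,
    mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = noise_scheduler_copy.timesteps[indices]

--guidance_scale plays a different role: FLUX.1 dev is guidance-distilled, so the scale is typically passed to the transformer as an embedded conditioning value during training rather than applied as classifier-free guidance at sampling time.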
@@ -793,7 +810,7 @@ def preprocess_train(examples):
         compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
     )
     train_dataset_with_tensors = train_dataset.map(
-        pixels_to_tensors_fn, batched=True, new_fingerprint=new_fingerprint_two, batch_size=64
+        pixels_to_tensors_fn, batched=True, new_fingerprint=new_fingerprint_two, batch_size=256
     )
     precomputed_dataset = concatenate_datasets(
         [train_dataset_with_embeddings, train_dataset_with_tensors.remove_columns(["text", "image"])], axis=1
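The bump from batch_size=64 to 256 trades peak host memory for fewer, larger map batches, amortizing per-batch Python/Arrow overhead during precomputation, while new_fingerprint pins the datasets cache key so the result is reused across runs. A minimal illustration of the knobs involved; train_dataset comes from the surrounding script, and the function body and fingerprint value are placeholders:

def pixels_to_tensors_fn(examples):
    # decode images and convert them to model-ready tensors (elided)
    return examples

train_dataset_with_tensors = train_dataset.map(
    pixels_to_tensors_fn,
    batched=True,
    batch_size=256,                # larger batches -> less per-batch overhead
    new_fingerprint="pixels-v2",   # placeholder; fixes the datasets cache key
)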