 from torchvision.transforms.functional import crop
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, AutoTokenizer
 from transformers.trainer_pt_utils import get_module_class_from_name
-from viztracer import VizTracer
+# from viztracer import VizTracer
 
 from torch._dispatch.python import suspend_functionalization
 from torch._subclasses.functional_tensor import disable_functional_mode
@@ -145,7 +145,7 @@ def wrap_module(
 
 def add_checkpoints(model):
     remat_classes = [get_module_class_from_name(model, "BasicTransformerBlock")]
-    import pdb; pdb.set_trace()
+    # import pdb; pdb.set_trace()
     def maybe_checkpoint(mod):
         if isinstance(mod, tuple(remat_classes)):
             return checkpoint_module(mod)
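
For context on the rematerialization pattern this function drives: `checkpoint_module` (from `torch_xla.distributed.fsdp`) wraps a module so its activations are recomputed during backward instead of being kept alive, trading compute for memory. A minimal sketch of applying such a wrapper to every matching block; the recursive traversal below is illustrative, not the repo's exact `wrap_module` helper:

```python
from torch_xla.distributed.fsdp import checkpoint_module

def checkpoint_blocks(model, classes):
    # Recursively replace matching submodules with gradient-checkpointed
    # wrappers. Illustrative traversal; names here are hypothetical.
    for name, child in list(model.named_children()):
        if isinstance(child, tuple(classes)):
            setattr(model, name, checkpoint_module(child))
        else:
            checkpoint_blocks(child, classes)
    return model
```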
@@ -172,6 +172,7 @@ def __init__(
         self.mesh = xs.get_global_mesh()
         self.dataloader = iter(dataloader)
         self.global_step = 0
+        # self.step_fn_compiled = torch.compile(self.step_fn, backend="openxla")
 
     def run_optimizer(self):
         self.optimizer.step()
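
The commented-out line above points at an alternative to lazy-tensor tracing: compiling the step function through dynamo with the `openxla` backend. A minimal sketch of that usage, assuming inputs already live on an XLA device:

```python
import torch

def step_fn(x, w):
    # Placeholder step: any tensor computation works.
    return (x @ w).sum()

# torch.compile with backend="openxla" captures the function as a graph
# once, rather than re-tracing lazy tensors on every step.
compiled_step = torch.compile(step_fn, backend="openxla")
```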
@@ -184,7 +185,6 @@ def start_training(self):
         assert measure_start_step < self.args.max_train_steps
         total_time = 0
         last_time = time.time()
-        tracer = None
         for step in range(0, self.args.max_train_steps):
             print("step: ", step)
             start_time = time.time()
@@ -193,13 +193,9 @@ def start_training(self):
             if step == measure_start_step and PROFILE_DIR is not None:
                 xm.wait_device_ops()
                 xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
-            if step == 15:
-                tracer = VizTracer()
-            else:
-                tracer = None
+
             with suspend_functionalization(), disable_functional_mode():
                 loss = self.step_fn(
-                    tracer,
                     batch["model_input"],
                     batch["prompt_embeds"],
                     batch["pooled_prompt_embeds"],
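
`xp.trace_detached` asks an already-running profiler server to capture a device trace for `profile_duration` milliseconds without blocking the training loop. A sketch of the setup it presumes, with an illustrative port and log directory:

```python
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

server = xp.start_server(9012)  # profiler server; port is illustrative
xm.wait_device_ops()            # let queued device work drain first
xp.trace_detached("localhost:9012", "/tmp/xla_profile", duration_ms=5000)
```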
@@ -229,107 +225,91 @@ def print_loss_closure(step, loss):
 
     def step_fn(
         self,
-        tracer,
         model_input,
         prompt_embeds,
         pooled_prompt_embeds,
         original_sizes,
         crop_top_lefts
     ):
-        # with VizTracer(output_file="forward.json") as tracer:
-        start_time = time.time()
-        if tracer is not None:
-            tracer.start()
-        self.optimizer.zero_grad()
-        noise = torch.randn_like(model_input).to(self.device, dtype=self.weight_dtype)
-        bsz = model_input.shape[0]
-        timesteps = torch.randint(
-            0,
-            self.noise_scheduler.config.num_train_timesteps,
-            (bsz,),
-            device=model_input.device,
-        )
-        timesteps = timesteps.long()
-        noisy_latents = self.noise_scheduler.add_noise(model_input, noise, timesteps)
-        noisy_latents = noisy_latents.to(self.device, dtype=self.weight_dtype)
-        # time ids
-        def compute_time_ids(original_size, crops_coords_top_left):
-            # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
-            target_size = torch.tensor([self.args.resolution, self.args.resolution]).to(self.device)
-            add_time_ids = torch.unsqueeze(torch.cat([original_size, crops_coords_top_left, target_size], axis=0), dim=0)
-            return add_time_ids
-
-        add_time_ids = torch.cat(
-            [compute_time_ids(s, c) for s, c in zip(original_sizes, crop_top_lefts)]
-        )
-        # Predict the noise residual
-        unet_added_conditions = {"time_ids": add_time_ids}
-        unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
-        # breakpoint()
-        model_pred = self.unet(
-            noisy_latents,
-            timesteps,
-            prompt_embeds,
-            added_cond_kwargs=unet_added_conditions,
-            return_dict=False,
-        )[0]
-        if self.args.prediction_type is not None:
-            # set prediction_type of scheduler if defined
-            self.noise_scheduler.register_to_config(prediction_type=self.args.prediction_type)
-        if self.noise_scheduler.config.prediction_type == "epsilon":
-            target = noise
-        elif self.noise_scheduler.config.prediction_type == "v_prediction":
-            target = self.noise_scheduler.get_velocity(model_input, noise, timesteps)
-        elif self.noise_scheduler.config.prediction_type == "sample":
-            # We set the target to latents here, but the model_pred will return the noise sample prediction.
-            target = model_input
-            # We will have to subtract the noise residual from the prediction to get the target sample.
-            model_pred = model_pred - noise
-        else:
-            raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
-
-        if tracer:
-            tracer.stop()
-            tracer.save(output_file="forward.json")
-        print(f"forward_time = {time.time()-start_time}")
         start_time = time.time()
-        # with VizTracer(output_file="backward.json") as tracer:
-
-        if tracer:
-            tracer.start()
-        if self.args.snr_gamma is None:
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-        else:
-            # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
-            # Since we predict the noise instead of x_0, the original formulation is slightly changed.
-            # This is discussed in Section 4.2 of the same paper.
-            snr = compute_snr(self.noise_scheduler, timesteps)
-            mse_loss_weights = torch.stack([snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
-                dim=1
+        with xp.Trace("optimizer_zero_grad"):
+            self.optimizer.zero_grad(True)
+        with xp.Trace("forward"):
+            noise = torch.randn_like(model_input).to(self.device, dtype=self.weight_dtype)
+            bsz = model_input.shape[0]
+            timesteps = torch.randint(
+                0,
+                self.noise_scheduler.config.num_train_timesteps,
+                (bsz,),
+                device=model_input.device,
+            )
+            timesteps = timesteps.long()
+            noisy_latents = self.noise_scheduler.add_noise(model_input, noise, timesteps)
+            noisy_latents = noisy_latents.to(self.device, dtype=self.weight_dtype)
+            # time ids
+            def compute_time_ids(original_size, crops_coords_top_left):
+                # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+                target_size = torch.tensor([self.args.resolution, self.args.resolution]).to(self.device)
+                add_time_ids = torch.unsqueeze(torch.cat([original_size, crops_coords_top_left, target_size], axis=0), dim=0)
+                return add_time_ids
+
+            add_time_ids = torch.cat(
+                [compute_time_ids(s, c) for s, c in zip(original_sizes, crop_top_lefts)]
+            )
+            # Predict the noise residual
+            unet_added_conditions = {"time_ids": add_time_ids}
+            unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
+            # breakpoint()
+            model_pred = self.unet(
+                noisy_latents,
+                timesteps,
+                prompt_embeds,
+                added_cond_kwargs=unet_added_conditions,
+                return_dict=False,
             )[0]
+            if self.args.prediction_type is not None:
+                # set prediction_type of scheduler if defined
+                self.noise_scheduler.register_to_config(prediction_type=self.args.prediction_type)
             if self.noise_scheduler.config.prediction_type == "epsilon":
-                mse_loss_weights = mse_loss_weights / snr
+                target = noise
             elif self.noise_scheduler.config.prediction_type == "v_prediction":
-                mse_loss_weights = mse_loss_weights / (snr + 1)
-
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
-            loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
-            loss = loss.mean()
-        loss.backward()
-        if tracer:
-            tracer.stop()
-            tracer.save(output_file="backward.json")
+                target = self.noise_scheduler.get_velocity(model_input, noise, timesteps)
+            elif self.noise_scheduler.config.prediction_type == "sample":
+                # We set the target to latents here, but the model_pred will return the noise sample prediction.
+                target = model_input
+                # We will have to subtract the noise residual from the prediction to get the target sample.
+                model_pred = model_pred - noise
+            else:
+                raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
+
+        print(f"forward_time = {time.time()-start_time}")
+        start_time = time.time()
+        with xp.Trace("backward"):
+            if self.args.snr_gamma is None:
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+            else:
+                # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+                # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+                # This is discussed in Section 4.2 of the same paper.
+                snr = compute_snr(self.noise_scheduler, timesteps)
+                mse_loss_weights = torch.stack([snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+                    dim=1
+                )[0]
+                if self.noise_scheduler.config.prediction_type == "epsilon":
+                    mse_loss_weights = mse_loss_weights / snr
+                elif self.noise_scheduler.config.prediction_type == "v_prediction":
+                    mse_loss_weights = mse_loss_weights / (snr + 1)
+
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+                loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+                loss = loss.mean()
+            loss.backward()
         print(f"backward time = {time.time()-start_time}")
         start_time = time.time()
-        # with xp.Trace("optimizer_step"):
-        if tracer:
-            tracer.start()
-        self.run_optimizer()
-        if tracer:
-            tracer.stop()
-            tracer.save(output_file="optimizer.json")
+        with xp.Trace("optimizer_step"):
+            self.run_optimizer()
         print(f"optimizer step = {time.time()-start_time}")
-        return loss
+        return model_pred
 
 
 def parse_args():
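
The `xp.Trace` context managers introduced above annotate the zero-grad, forward, backward, and optimizer regions so they show up as named scopes in the captured profile. The pattern in isolation, with placeholder names:

```python
import torch_xla.debug.profiler as xp

def train_step(model, inputs, target, loss_fn):
    # Scopes nest and appear in the XLA/TensorBoard profile, letting
    # time be attributed to each phase of the step.
    with xp.Trace("forward"):
        pred = model(inputs)
    with xp.Trace("backward"):
        loss = loss_fn(pred, target)
        loss.backward()
    return loss
```

Note that with lazy XLA tensors, the surrounding `time.time()` prints mostly measure Python-side graph tracing; actual device execution is deferred until the step is materialized.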
@@ -567,7 +547,12 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
         prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
         prompt_embeds_list.append(prompt_embeds)
 
+
     prompt_embeds = torch.concat(prompt_embeds_list, dim=-1).to(dtype=dtype)
+    print(prompt_embeds.shape)
+    p3d = (0, 0, 0, 128 - 77)
+    prompt_embeds = F.pad(prompt_embeds, p3d, "constant", 0)
+    print(prompt_embeds.shape)
     pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1).to(dtype=dtype)
     return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}
 
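
On the padding above: `F.pad` consumes the pad tuple from the last dimension backward, so `(0, 0, 0, 128 - 77)` leaves the embedding dimension untouched and right-pads the sequence dimension from 77 to 128 tokens, presumably to give every batch one static shape and avoid XLA recompilation. In isolation:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 77, 2048)  # (batch, seq_len, embed_dim); sizes illustrative
# Pairs apply from the last dim backward: (emb_lo, emb_hi, seq_lo, seq_hi)
y = F.pad(x, (0, 0, 0, 128 - 77), "constant", 0)
print(y.shape)  # torch.Size([2, 128, 2048])
```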
@@ -580,7 +565,8 @@ def compute_vae_encodings(batch, vae):
     with torch.no_grad():
         model_input = vae.encode(pixel_values).latent_dist.sample()
     model_input = model_input * vae.config.scaling_factor
-    return {"model_input": model_input}
+    xm.mark_step()
+    return {"model_input": model_input.cpu()}
 
 
 def load_dataset(args):
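
`xm.mark_step()` cuts the lazy graph and launches the pending VAE encode immediately, and the `.cpu()` then materializes the result on the host so the `datasets` cache stores concrete tensors instead of accumulating one large deferred graph across the whole map. The shape of the pattern, as a sketch with a hypothetical function name:

```python
import torch
import torch_xla.core.xla_model as xm

def encode_batch(vae, pixel_values):
    with torch.no_grad():
        latents = vae.encode(pixel_values).latent_dist.sample()
        latents = latents * vae.config.scaling_factor
    xm.mark_step()        # flush the lazy graph: run the encode now
    return latents.cpu()  # materialize on host for dataset caching
```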
@@ -770,16 +756,20 @@ def preprocess_train(examples):
     )
     compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
     from datasets.fingerprint import Hasher
-
+    # import pdb; pdb.set_trace()
+    old_batch_size = args.train_batch_size
+    args.train_batch_size = 21
     new_fingerprint = Hasher.hash(args)
+    args.train_batch_size = 64
     new_fingerprint_for_vae = Hasher.hash((args.pretrained_model_name_or_path, args))
+    args.train_batch_size = old_batch_size
     train_dataset_with_embeddings = train_dataset.map(
-        compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
+        compute_embeddings_fn, batched=True, batch_size=50, new_fingerprint=new_fingerprint
     )
     train_dataset_with_vae = train_dataset.map(
         compute_vae_encodings_fn,
         batched=True,
-        batch_size=args.train_batch_size,
+        batch_size=50,
         new_fingerprint=new_fingerprint_for_vae,
     )
     precomputed_dataset = concatenate_datasets(
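
Background on the fingerprint juggling above: `datasets` keys its on-disk `.map()` cache by a fingerprint of the inputs, and `Hasher.hash` folds the entire `args` namespace in, so temporarily rewriting `train_batch_size` before hashing pins the fingerprints to the values used when the cache was originally built, letting runs with a different batch size reuse the precomputed embeddings. The mechanism in isolation:

```python
from datasets.fingerprint import Hasher

# Equal inputs hash to equal fingerprints, which is what lets a later
# .map(new_fingerprint=...) call hit the existing cache.
fp_a = Hasher.hash({"train_batch_size": 21})
fp_b = Hasher.hash({"train_batch_size": 64})
assert fp_a != fp_b
```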
@@ -794,14 +784,6 @@ def collate_fn(examples):
         crop_top_lefts = torch.stack([torch.tensor(example["crop_top_lefts"]) for example in examples])
         prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]).to(dtype=weight_dtype)
         pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples]).to(dtype=weight_dtype)
-        # print("model_input.shape: ", model_input.shape)
-        # print("model_input.dtype: ", model_input.dtype)
-        # print("prompt_embeds.shape: ", prompt_embeds.shape)
-        # print("prompt_embeds.dtype: ", prompt_embeds.dtype)
-        # print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape)
-        # print("pooled_prompt_embeds.dtype: ", pooled_prompt_embeds.dtype)
-        # print("original_sizes.shape: ", original_sizes.shape)
-        # print("crop_top_lefts.shape: ", crop_top_lefts.shape)
         return {
             "model_input": model_input,
             "prompt_embeds": prompt_embeds,
@@ -846,7 +828,7 @@ def collate_fn(examples):
     )
     print(f" Total optimization steps = {args.max_train_steps}")
 
-    unet = add_checkpoints(unet)
+    # unet = add_checkpoints(unet)
 
     trainer = TrainSD(
         weight_dtype=weight_dtype,