@@ -1028,11 +1028,6 @@ def encode_prompt(
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
 
-    if hasattr(text_encoders[0], "module"):
-        dtype = text_encoders[0].module.dtype
-    else:
-        dtype = text_encoders[0].dtype
-
     pooled_prompt_embeds_1 = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
         tokenizer=tokenizers[0],
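
Note on the hunk above: the dtype probe was unused after the refactor, so encode_prompt now relies entirely on the per-encoder helpers. For orientation, the sketch below shows how a CLIP pooled embedding is typically produced; the helper name and signature are illustrative assumptions and not the script's actual _encode_prompt_with_clip.

import torch

def clip_pooled_embeds(text_encoder, tokenizer, prompt, device):
    # Hypothetical sketch: tokenize the prompt, run a CLIPTextModelWithProjection,
    # and keep only the pooled (projected) embedding used as a conditioning vector.
    inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        output = text_encoder(inputs.input_ids.to(device))
    return output.text_embeds  # shape: (batch_size, projection_dim)
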
@@ -1179,21 +1174,50 @@ def main(args):
                 exist_ok=True,
             ).repo_id
 
-    # Load the tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(
+    # Load the tokenizers
+    tokenizer_one = CLIPTokenizer.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="tokenizer",
         revision=args.revision,
     )
+    tokenizer_two = CLIPTokenizer.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_2",
+        revision=args.revision,
+    )
+    tokenizer_three = T5TokenizerFast.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_3",
+        revision=args.revision,
+    )
+
+    tokenizer_four = PreTrainedTokenizerFast.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_4",
+        revision=args.revision,
+    )
+
+    # import correct text encoder classes
+    text_encoder_cls_one = import_model_class_from_model_name_or_path(
+        args.pretrained_model_name_or_path, args.revision
+    )
+    text_encoder_cls_two = import_model_class_from_model_name_or_path(
+        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+    )
+    text_encoder_cls_three = import_model_class_from_model_name_or_path(
+        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
+    )
+    text_encoder_cls_four = import_model_class_from_model_name_or_path(
+        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_4"
+    )
 
     # Load scheduler and models
     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision
     )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
-    text_encoder = Gemma2Model.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
-    )
+    text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three, text_encoder_cls_four)
+
     vae = AutoencoderKL.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="vae",
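
Note: import_model_class_from_model_name_or_path and load_text_encoders are defined earlier in the script and are not part of this hunk. The usual pattern in the diffusers DreamBooth scripts looks roughly like the sketch below; the exact class branches are an assumption based on HiDream's CLIP/CLIP/T5/Llama text encoder stack.

from transformers import PretrainedConfig

def import_model_class_from_model_name_or_path(pretrained_model_name_or_path, revision, subfolder="text_encoder"):
    # Resolve the concrete encoder class from the checkpoint config instead of hard-coding it.
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = config.architectures[0]
    if model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    elif model_class == "T5EncoderModel":
        from transformers import T5EncoderModel

        return T5EncoderModel
    elif model_class == "LlamaForCausalLM":
        from transformers import LlamaForCausalLM

        return LlamaForCausalLM
    else:
        raise ValueError(f"{model_class} is not supported.")

load_text_encoders would then simply call from_pretrained on each resolved class with subfolder "text_encoder", "text_encoder_2", "text_encoder_3", and "text_encoder_4".
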
@@ -1207,7 +1231,10 @@ def main(args):
     # We only train the additional adapter LoRA layers
     transformer.requires_grad_(False)
     vae.requires_grad_(False)
-    text_encoder.requires_grad_(False)
+    text_encoder_one.requires_grad_(False)
+    text_encoder_two.requires_grad_(False)
+    text_encoder_three.requires_grad_(False)
+    text_encoder_four.requires_grad_(False)
 
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
@@ -1226,17 +1253,10 @@ def main(args):
     # keep VAE in FP32 to ensure numerical stability.
     vae.to(dtype=torch.float32)
     transformer.to(accelerator.device, dtype=weight_dtype)
-    # because Gemma2 is particularly suited for bfloat16.
-    text_encoder.to(dtype=torch.bfloat16)
-
-    # Initialize a text encoding pipeline and keep it to CPU for now.
-    text_encoding_pipeline = HiDreamImagePipeline.from_pretrained(
-        args.pretrained_model_name_or_path,
-        vae=None,
-        transformer=None,
-        text_encoder=text_encoder,
-        tokenizer=tokenizer,
-    )
+    text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+    text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+    text_encoder_three.to(accelerator.device, dtype=weight_dtype)
+    text_encoder_four.to(accelerator.device, dtype=weight_dtype)
 
     if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
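
Note: weight_dtype used above is derived from the accelerator's mixed-precision setting earlier in the script; the standard pattern is reproduced here for reference only.

import torch

# Non-trainable weights are cast to half precision because they are only used for
# inference during LoRA training; weight_dtype tracks accelerator.mixed_precision.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16
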
@@ -1417,47 +1437,45 @@ def load_model_hook(models, input_dir):
         num_workers=args.dataloader_num_workers,
     )
 
-    def compute_text_embeddings(prompt, text_encoding_pipeline):
-        text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
+    tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three, tokenizer_four]
+    text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four]
+    def compute_text_embeddings(prompt, text_encoders, tokenizers):
         with torch.no_grad():
-            prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(
-                prompt,
-                max_sequence_length=args.max_sequence_length,
-                system_prompt=args.system_prompt,
+            prompt_embeds, pooled_prompt_embeds = encode_prompt(
+                text_encoders, tokenizers, prompt, args.max_sequence_length
             )
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
-        prompt_embeds = prompt_embeds.to(transformer.dtype)
-        return prompt_embeds, prompt_attention_mask
+            prompt_embeds = prompt_embeds.to(accelerator.device)
+            pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+        return prompt_embeds, pooled_prompt_embeds
 
     # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
     if not train_dataset.custom_instance_prompts:
-        instance_prompt_hidden_states, instance_prompt_attention_mask = compute_text_embeddings(
-            args.instance_prompt, text_encoding_pipeline
+        instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
+            args.instance_prompt, text_encoders, tokenizers
         )
 
     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
-        class_prompt_hidden_states, class_prompt_attention_mask = compute_text_embeddings(
-            args.class_prompt, text_encoding_pipeline
+        class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
+            args.class_prompt, text_encoders, tokenizers
         )
 
     # Clear the memory here
     if not train_dataset.custom_instance_prompts:
-        del text_encoder, tokenizer
+        del text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four, tokenizer_one, tokenizer_two, tokenizer_three, tokenizer_four
         free_memory()
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
     if not train_dataset.custom_instance_prompts:
         prompt_embeds = instance_prompt_hidden_states
-        prompt_attention_mask = instance_prompt_attention_mask
+        pooled_prompt_embeds = instance_pooled_prompt_embeds
         if args.with_prior_preservation:
             prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
-            prompt_attention_mask = torch.cat([prompt_attention_mask, class_prompt_attention_mask], dim=0)
+            pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
 
     vae_config_scaling_factor = vae.config.scaling_factor
     vae_config_shift_factor = vae.config.shift_factor
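
Note: when prior preservation is enabled, each batch stacks instance pixels followed by class pixels, so the statically packed embeddings above are concatenated in the same order. A toy shape check, with purely illustrative dimensions:

import torch

# Illustrative shapes only; the real tensors come from encode_prompt above.
instance_embeds = torch.randn(1, 128, 4096)
class_embeds = torch.randn(1, 128, 4096)
prompt_embeds = torch.cat([instance_embeds, class_embeds], dim=0)
assert prompt_embeds.shape[0] == 2  # row 0 pairs with instance images, row 1 with class images
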
@@ -1506,7 +1524,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        tracker_name = "dreambooth-lumina2-lora"
+        tracker_name = "dreambooth-hidream-lora"
         accelerator.init_trackers(tracker_name, config=vars(args))
 
     # Train!
@@ -1580,7 +1598,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             with accelerator.accumulate(models_to_accumulate):
                 # encode batch prompts when custom prompts are provided for each image -
                 if train_dataset.custom_instance_prompts:
-                    prompt_embeds, prompt_attention_mask = compute_text_embeddings(prompts, text_encoding_pipeline)
+                    prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(prompts, text_encoders, tokenizers)
 
                 # Convert images to latent space
                 if args.cache_latents:
@@ -1594,6 +1612,24 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
 
+                if model_input.shape[-2] != model_input.shape[-1]:
+                    B, C, H, W = model_input.shape
+                    pH, pW = H // transformer.config.patch_size, W // transformer.config.patch_size
+
+                    img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
+                    img_ids = torch.zeros(pH, pW, 3)
+                    img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
+                    img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
+                    img_ids = img_ids.reshape(pH * pW, -1)
+                    img_ids_pad = torch.zeros(transformer.max_seq, 3)
+                    img_ids_pad[: pH * pW, :] = img_ids
+
+                    img_sizes = img_sizes.unsqueeze(0).to(model_input.device)
+                    img_ids = img_ids_pad.unsqueeze(0).to(model_input.device)
+
+                else:
+                    img_sizes = img_ids = None
+
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(model_input)
                 bsz = model_input.shape[0]
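
Note: the non-square branch above builds per-patch (row, col) coordinates and zero-pads them to the transformer's maximum sequence length. A standalone example with pH=2, pW=3 (the max_seq value here is illustrative):

import torch

pH, pW, max_seq = 2, 3, 8  # max_seq is illustrative; the script reads it from the transformer
img_ids = torch.zeros(pH, pW, 3)
img_ids[..., 1] += torch.arange(pH)[:, None]
img_ids[..., 2] += torch.arange(pW)[None, :]
img_ids = img_ids.reshape(pH * pW, -1)
# rows: (0, 0, 0) (0, 0, 1) (0, 0, 2) (0, 1, 0) (0, 1, 1) (0, 1, 2)
img_ids_pad = torch.zeros(max_seq, 3)
img_ids_pad[: pH * pW] = img_ids  # the remaining rows stay zero-padded
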
@@ -1612,22 +1648,21 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
                 # Add noise according to flow matching.
                 # zt = (1 - texp) * x + texp * z1
-                # Lumina2 reverses the lerp i.e., sigma of 1.0 should mean `model_input`
                 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
-                noisy_model_input = (1.0 - sigmas) * noise + sigmas * model_input
+                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
 
                 # Predict the noise residual
-                # scale the timesteps (reversal not needed as we used a reverse lerp above already)
-                timesteps = timesteps / noise_scheduler.config.num_train_timesteps
                 model_pred = transformer(
                     hidden_states=noisy_model_input,
                     encoder_hidden_states=prompt_embeds.repeat(len(prompts), 1, 1)
                     if not train_dataset.custom_instance_prompts
                     else prompt_embeds,
-                    encoder_attention_mask=prompt_attention_mask.repeat(len(prompts), 1)
+                    pooled_embeds=pooled_prompt_embeds.repeat(len(prompts), 1)
                     if not train_dataset.custom_instance_prompts
-                    else prompt_attention_mask,
+                    else pooled_prompt_embeds,
                     timestep=timesteps,
+                    img_sizes=img_sizes,
+                    img_ids=img_ids,
                     return_dict=False,
                 )[0]