 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
-from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast, PreTrainedTokenizerFast
+from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast, PreTrainedTokenizerFast, LlamaForCausalLM
 
 import diffusers
 from diffusers import (
@@ -146,7 +146,7 @@ def save_model_card(
     model_card = populate_model_card(model_card, tags=tags)
     model_card.save(os.path.join(repo_folder, "README.md"))
 
-def load_text_encoders(class_one, class_two, class_three, class_four):
+def load_text_encoders(class_one, class_two, class_three):
     text_encoder_one = class_one.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )
@@ -156,9 +156,11 @@ def load_text_encoders(class_one, class_two, class_three, class_four):
     text_encoder_three = class_three.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant
     )
-    text_encoder_four = class_four.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder_4", revision=args.revision, variant=args.variant
-    )
+    text_encoder_four = LlamaForCausalLM.from_pretrained(
+        "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        output_hidden_states=True,
+        output_attentions=True,
+        torch_dtype=torch.bfloat16,)
     return text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four
 
 def log_validation(
@@ -211,18 +213,14 @@ def import_model_class_from_model_name_or_path(
         pretrained_model_name_or_path, subfolder=subfolder, revision=revision
     )
     model_class = text_encoder_config.architectures[0]
-    if model_class == "CLIPTextModelWithProjection":
-        from transformers import CLIPTextModel
+    if model_class == "CLIPTextModelWithProjection" or model_class == "CLIPTextModel":
+        from transformers import CLIPTextModelWithProjection
 
-        return CLIPTextModel
+        return CLIPTextModelWithProjection
     elif model_class == "T5EncoderModel":
         from transformers import T5EncoderModel
 
         return T5EncoderModel
-    elif model_class == "LlamaForCausalLM":
-        from transformers import LlamaForCausalLM
-
-        return LlamaForCausalLM
     else:
         raise ValueError(f"{model_class} is not supported.")
 
@@ -1184,8 +1182,7 @@ def main(args):
     )
 
     tokenizer_four = PreTrainedTokenizerFast.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="tokenizer_4",
+        "meta-llama/Meta-Llama-3.1-8B-Instruct",
         revision=args.revision,
     )
 
@@ -1199,16 +1196,13 @@ def main(args):
     text_encoder_cls_three = import_model_class_from_model_name_or_path(
         args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
     )
-    text_encoder_cls_four = import_model_class_from_model_name_or_path(
-        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_4"
-    )
 
     # Load scheduler and models
     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision
     )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
-    text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three, text_encoder_cls_four)
+    text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three)
 
     vae = AutoencoderKL.from_pretrained(
         args.pretrained_model_name_or_path,
@@ -1740,6 +1734,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         # create pipeline
         pipeline = HiDreamImagePipeline.from_pretrained(
             args.pretrained_model_name_or_path,
+            # tokenizer_4=tokenizer_4,
+            # text_encoder_4=text_encoder_4,
             transformer=accelerator.unwrap_model(transformer),
             revision=args.revision,
             variant=args.variant,
@@ -1777,6 +1773,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         # Load previous pipeline
         pipeline = HiDreamImagePipeline.from_pretrained(
             args.pretrained_model_name_or_path,
+            # tokenizer_4=tokenizer_4,
+            # text_encoder_4=text_encoder_4,
             revision=args.revision,
             variant=args.variant,
             torch_dtype=weight_dtype,
0 commit comments