Skip to content

Commit de1654a

Browse files
committed
Fix imports and the loading of tokenizer_4 and text_encoder_4
1 parent 31aa0a2 commit de1654a

File tree

1 file changed

+16
-18
lines changed

1 file changed

+16
-18
lines changed

examples/dreambooth/train_dreambooth_lora_hidream.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from torchvision import transforms
4343
from torchvision.transforms.functional import crop
4444
from tqdm.auto import tqdm
45-
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast, PreTrainedTokenizerFast
45+
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast, PreTrainedTokenizerFast, LlamaForCausalLM
4646

4747
import diffusers
4848
from diffusers import (
@@ -146,7 +146,7 @@ def save_model_card(
146146
model_card = populate_model_card(model_card, tags=tags)
147147
model_card.save(os.path.join(repo_folder, "README.md"))
148148

149-
def load_text_encoders(class_one, class_two, class_three, class_four):
149+
def load_text_encoders(class_one, class_two, class_three):
150150
text_encoder_one = class_one.from_pretrained(
151151
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
152152
)
@@ -156,9 +156,11 @@ def load_text_encoders(class_one, class_two, class_three, class_four):
156156
text_encoder_three = class_three.from_pretrained(
157157
args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant
158158
)
159-
text_encoder_four = class_four.from_pretrained(
160-
args.pretrained_model_name_or_path, subfolder="text_encoder_4", revision=args.revision, variant=args.variant
161-
)
159+
text_encoder_four = LlamaForCausalLM.from_pretrained(
160+
"meta-llama/Meta-Llama-3.1-8B-Instruct",
161+
output_hidden_states=True,
162+
output_attentions=True,
163+
torch_dtype=torch.bfloat16,)
162164
return text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four
163165

164166
def log_validation(
@@ -211,18 +213,14 @@ def import_model_class_from_model_name_or_path(
211213
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
212214
)
213215
model_class = text_encoder_config.architectures[0]
214-
if model_class == "CLIPTextModelWithProjection":
215-
from transformers import CLIPTextModel
216+
if model_class == "CLIPTextModelWithProjection" or model_class == "CLIPTextModel":
217+
from transformers import CLIPTextModelWithProjection
216218

217-
return CLIPTextModel
219+
return CLIPTextModelWithProjection
218220
elif model_class == "T5EncoderModel":
219221
from transformers import T5EncoderModel
220222

221223
return T5EncoderModel
222-
elif model_class == "LlamaForCausalLM":
223-
from transformers import LlamaForCausalLM
224-
225-
return LlamaForCausalLM
226224
else:
227225
raise ValueError(f"{model_class} is not supported.")
228226

@@ -1184,8 +1182,7 @@ def main(args):
11841182
)
11851183

11861184
tokenizer_four = PreTrainedTokenizerFast.from_pretrained(
1187-
args.pretrained_model_name_or_path,
1188-
subfolder="tokenizer_4",
1185+
"meta-llama/Meta-Llama-3.1-8B-Instruct",
11891186
revision=args.revision,
11901187
)
11911188

@@ -1199,16 +1196,13 @@ def main(args):
11991196
text_encoder_cls_three = import_model_class_from_model_name_or_path(
12001197
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
12011198
)
1202-
text_encoder_cls_four = import_model_class_from_model_name_or_path(
1203-
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_4"
1204-
)
12051199

12061200
# Load scheduler and models
12071201
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
12081202
args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision
12091203
)
12101204
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
1211-
text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three, text_encoder_cls_four)
1205+
text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three)
12121206

12131207
vae = AutoencoderKL.from_pretrained(
12141208
args.pretrained_model_name_or_path,
@@ -1740,6 +1734,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17401734
# create pipeline
17411735
pipeline = HiDreamImagePipeline.from_pretrained(
17421736
args.pretrained_model_name_or_path,
1737+
# tokenizer_4=tokenizer_4,
1738+
# text_encoder_4=text_encoder_4,
17431739
transformer=accelerator.unwrap_model(transformer),
17441740
revision=args.revision,
17451741
variant=args.variant,
@@ -1777,6 +1773,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17771773
# Load previous pipeline
17781774
pipeline = HiDreamImagePipeline.from_pretrained(
17791775
args.pretrained_model_name_or_path,
1776+
# tokenizer_4=tokenizer_4,
1777+
# text_encoder_4=text_encoder_4,
17801778
revision=args.revision,
17811779
variant=args.variant,
17821780
torch_dtype=weight_dtype,

0 commit comments

Comments (0)