Skip to content

Commit 0565932

Browse files
committed
add unwrap_model for accelerator, torch.no_grad context for validation, fix accelerator.accumulate call in advanced script
1 parent 2d8ca60 commit 0565932

File tree

2 files changed: +17 −10 lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,10 @@ def log_validation(
231231
autocast_ctx = torch.autocast(accelerator.device.type)
232232

233233
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
234-
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
235-
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
236-
)
234+
with torch.no_grad():
235+
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
236+
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
237+
)
237238
images = []
238239
for _ in range(args.num_validation_images):
239240
with autocast_ctx:
@@ -2044,6 +2045,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
20442045
pivoted_tr = True
20452046

20462047
for step, batch in enumerate(train_dataloader):
2048+
models_to_accumulate = [transformer]
2049+
if not freeze_text_encoder:
2050+
models_to_accumulate.extend([text_encoder_one])
2051+
if args.enable_t5_ti:
2052+
models_to_accumulate.extend([text_encoder_two])
20472053
if pivoted_te:
20482054
# stopping optimization of text_encoder params
20492055
optimizer.param_groups[te_idx]["lr"] = 0.0
@@ -2052,7 +2058,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
20522058
logger.info(f"PIVOT TRANSFORMER {epoch}")
20532059
optimizer.param_groups[0]["lr"] = 0.0
20542060

2055-
with accelerator.accumulate(transformer):
2061+
with accelerator.accumulate(models_to_accumulate):
20562062
prompts = batch["prompts"]
20572063

20582064
# encode batch prompts when custom prompts are provided for each image -

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,10 @@ def log_validation(
185185
autocast_ctx = torch.autocast(accelerator.device.type)
186186

187187
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
188-
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
189-
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
190-
)
188+
with torch.no_grad():
189+
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
190+
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
191+
)
191192
images = []
192193
for _ in range(args.num_validation_images):
193194
with autocast_ctx:
@@ -940,7 +941,7 @@ def _encode_prompt_with_t5(
940941

941942
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
942943

943-
dtype = text_encoder.dtype
944+
dtype = unwrap_model(text_encoder).dtype
944945
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
945946

946947
_, seq_len, _ = prompt_embeds.shape
@@ -983,7 +984,7 @@ def _encode_prompt_with_clip(
983984

984985
# Use pooled output of CLIPTextModel
985986
prompt_embeds = prompt_embeds.pooler_output
986-
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
987+
prompt_embeds = prompt_embeds.to(dtype=unwrap_model(text_encoder).dtype, device=device)
987988

988989
# duplicate text embeddings for each generation per prompt, using mps friendly method
989990
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -1002,7 +1003,7 @@ def encode_prompt(
10021003
text_input_ids_list=None,
10031004
):
10041005
prompt = [prompt] if isinstance(prompt, str) else prompt
1005-
dtype = text_encoders[0].dtype
1006+
dtype = unwrap_model(text_encoders[0]).dtype
10061007

10071008
pooled_prompt_embeds = _encode_prompt_with_clip(
10081009
text_encoder=text_encoders[0],

0 commit comments

Comments (0)