Skip to content

Commit 9df3d84

Browse files
dg845sayakpaul
andauthored
Fix LCM distillation bug when creating the guidance scale embeddings using multiple GPUs. (#6279)
Fix bug when creating the guidance embeddings using multiple GPUs. Co-authored-by: Sayak Paul <[email protected]>
1 parent c751449 commit 9df3d84

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

examples/consistency_distillation/train_lcm_distill_sd_wds.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,7 @@ def main(args):
889889
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
890890
if teacher_unet.config.time_cond_proj_dim is None:
891891
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
892+
time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
892893
unet = UNet2DConditionModel(**teacher_unet.config)
893894
# load teacher_unet weights into unet
894895
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
@@ -1175,7 +1176,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
11751176

11761177
# 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
11771178
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
1178-
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
1179+
w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim)
11791180
w = w.reshape(bsz, 1, 1, 1)
11801181
# Move to U-Net device and dtype
11811182
w = w.to(device=latents.device, dtype=latents.dtype)

examples/consistency_distillation/train_lcm_distill_sdxl_wds.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,7 @@ def main(args):
948948
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
949949
if teacher_unet.config.time_cond_proj_dim is None:
950950
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
951+
time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
951952
unet = UNet2DConditionModel(**teacher_unet.config)
952953
# load teacher_unet weights into unet
953954
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
@@ -1273,7 +1274,7 @@ def compute_embeddings(
12731274

12741275
# 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
12751276
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
1276-
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
1277+
w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim)
12771278
w = w.reshape(bsz, 1, 1, 1)
12781279
# Move to U-Net device and dtype
12791280
w = w.to(device=latents.device, dtype=latents.dtype)

0 commit comments

Comments
 (0)