Skip to content

Commit 17cece0

Browse files
authored
Fix bug in LCM Distillation Scripts when args.unet_time_cond_proj_dim is used (#6523)
* Fix bug where unet's time_cond_proj_dim is not set correctly if using args.unet_time_cond_proj_dim. * make style
1 parent a551ddf commit 17cece0

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

examples/consistency_distillation/train_lcm_distill_sd_wds.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -921,10 +921,12 @@ def main(args):
921921

922922
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
923923
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
924-
if teacher_unet.config.time_cond_proj_dim is None:
925-
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
926-
time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
927-
unet = UNet2DConditionModel(**teacher_unet.config)
924+
time_cond_proj_dim = (
925+
teacher_unet.config.time_cond_proj_dim
926+
if teacher_unet.config.time_cond_proj_dim is not None
927+
else args.unet_time_cond_proj_dim
928+
)
929+
unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim)
928930
# load teacher_unet weights into unet
929931
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
930932
unet.train()

examples/consistency_distillation/train_lcm_distill_sdxl_wds.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -980,10 +980,12 @@ def main(args):
980980

981981
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
982982
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
983-
if teacher_unet.config.time_cond_proj_dim is None:
984-
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
985-
time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
986-
unet = UNet2DConditionModel(**teacher_unet.config)
983+
time_cond_proj_dim = (
984+
teacher_unet.config.time_cond_proj_dim
985+
if teacher_unet.config.time_cond_proj_dim is not None
986+
else args.unet_time_cond_proj_dim
987+
)
988+
unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim)
987989
# load teacher_unet weights into unet
988990
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
989991
unet.train()

0 commit comments

Comments
 (0)