Commit 86cbc0f

Merge remote-tracking branch 'origin/video-loras' into video-loras
2 parents: d2dd6ae + c464455

2 files changed: +14 additions, −6 deletions


examples/research_projects/autoencoderkl/train_autoencoderkl.py

Lines changed: 10 additions & 2 deletions
```diff
@@ -627,6 +627,7 @@ def main(args):
     ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
     perceptual_loss = lpips.LPIPS(net="vgg").eval()
     discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
+    discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)

     # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
     def unwrap_model(model):
```
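The added line converts the discriminator's BatchNorm layers to SyncBatchNorm so that, under multi-GPU training, batch statistics are computed across all processes rather than per device. A minimal, self-contained sketch of what `convert_sync_batchnorm` does (the toy model below is an assumption for illustration, not the script's `NLayerDiscriminator`):

```python
# Minimal sketch: convert_sync_batchnorm walks the module tree and replaces
# every BatchNorm*d layer with torch.nn.SyncBatchNorm, which synchronizes
# batch statistics across processes when a distributed group is active.
# The toy model here is a placeholder, not the script's NLayerDiscriminator.
import torch
import torch.nn as nn

disc = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2),
)
disc = torch.nn.SyncBatchNorm.convert_sync_batchnorm(disc)
print(disc)  # the BatchNorm2d has been swapped for SyncBatchNorm
```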
```diff
@@ -951,13 +952,20 @@ def load_model_hook(models, input_dir):
                 logits_fake = discriminator(reconstructions)
                 disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss
                 disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
-                disc_loss = disc_factor * disc_loss(logits_real, logits_fake)
+                d_loss = disc_factor * disc_loss(logits_real, logits_fake)
                 logs = {
-                    "disc_loss": disc_loss.detach().mean().item(),
+                    "disc_loss": d_loss.detach().mean().item(),
                     "logits_real": logits_real.detach().mean().item(),
                     "logits_fake": logits_fake.detach().mean().item(),
                     "disc_lr": disc_lr_scheduler.get_last_lr()[0],
                 }
+                accelerator.backward(d_loss)
+                if accelerator.sync_gradients:
+                    params_to_clip = discriminator.parameters()
+                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                disc_optimizer.step()
+                disc_lr_scheduler.step()
+                disc_optimizer.zero_grad(set_to_none=args.set_grads_to_none)
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
```
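Besides renaming the loss tensor to `d_loss` (the name `disc_loss` previously referred to both the loss function and its output), this hunk adds the discriminator's actual optimization step, which the old code computed a loss for but never applied: backward pass, gradient clipping on synchronized steps, optimizer and scheduler step, and gradient reset. A condensed sketch of the same Accelerate pattern, with a placeholder model, optimizer, and hyperparameters standing in for the script's:

```python
# Condensed sketch of the discriminator update pattern added in the diff,
# using Hugging Face Accelerate. The tiny model, optimizer, and loss are
# placeholders; only the ordering of the steps mirrors the training script.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
discriminator = torch.nn.Linear(8, 1)
disc_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
disc_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(disc_optimizer)
discriminator, disc_optimizer = accelerator.prepare(discriminator, disc_optimizer)

x = torch.randn(4, 8, device=accelerator.device)
d_loss = discriminator(x).mean()        # stand-in for the hinge/vanilla GAN loss

accelerator.backward(d_loss)            # backprop through Accelerate
if accelerator.sync_gradients:          # only clip on real (non-accumulated) steps
    accelerator.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
disc_optimizer.step()
disc_lr_scheduler.step()
disc_optimizer.zero_grad(set_to_none=True)
```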

src/diffusers/loaders/lora_pipeline.py

Lines changed: 4 additions & 4 deletions
@@ -4268,11 +4268,11 @@ def _maybe_expand_t2v_lora_for_i2v(
42684268

42694269
for i in range(num_blocks):
42704270
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
4271-
state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
4272-
state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_A.weight"]
4271+
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
4272+
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
42734273
)
4274-
state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
4275-
state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_B.weight"]
4274+
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
4275+
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
42764276
)
42774277

42784278
return state_dict
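This helper zero-initializes LoRA weights for the image-attention projections (`add_k_proj`/`add_v_proj`) that exist only in the image-to-video transformer, so a text-to-video LoRA checkpoint can still be loaded; the fix prefixes the keys with `transformer.` to match the rest of the state dict and clones shapes from the existing `to_k` weights. A toy reproduction of the expansion, where the rank, dimensions, and starting dict are made up for illustration:

```python
# Toy sketch of the key-expansion idea: given a T2V LoRA state dict, add
# zero-filled LoRA weights for the I2V-only image-attention projections so
# the checkpoint loads into the I2V model. Shapes and values are invented.
import torch

num_blocks, rank, dim = 2, 4, 16
state_dict = {}
for i in range(num_blocks):
    for proj in ("to_k", "to_v"):
        state_dict[f"transformer.blocks.{i}.attn2.{proj}.lora_A.weight"] = torch.randn(rank, dim)
        state_dict[f"transformer.blocks.{i}.attn2.{proj}.lora_B.weight"] = torch.randn(dim, rank)

# Mirror the diff: zero tensors shaped like the existing `to_k` LoRA weights.
for i in range(num_blocks):
    for c in ("add_k_proj", "add_v_proj"):
        state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
            state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
        )
        state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
            state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
        )
```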
