Commit 296ecdd

refactor variable name and freezing unet

1 parent: e8458f6

File tree

1 file changed: +5 / -6 lines

examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py

Lines changed: 5 additions & 6 deletions
@@ -590,8 +590,7 @@ def main():
         weight_dtype = torch.bfloat16
 
     # Freeze the unet parameters before adding adapters
-    for param in unet.parameters():
-        param.requires_grad_(False)
+    unet.requires_grad_(False)
 
     unet_lora_config = LoraConfig(
         r=args.rank,
@@ -628,7 +627,7 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
 
-    lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
+    trainable_params = filter(lambda p: p.requires_grad, unet.parameters())
 
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
@@ -699,7 +698,7 @@ def load_model_hook(models, input_dir):
 
     # train on only lora_layers
     optimizer = optimizer_cls(
-        lora_layers,
+        trainable_params,
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -1014,15 +1013,15 @@ def collate_fn(examples):
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(lora_layers, args.max_grad_norm)
+                    accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 if args.use_ema:
-                    ema_unet.step(lora_layers)
+                    ema_unet.step(trainable_params)
                 progress_bar.update(1)
                 global_step += 1
                 accelerator.log({"train_loss": train_loss}, step=global_step)
