
Commit 2b9bc1d

make base code changes referred from train_instructpix2pix script in examples

1 parent 96c376a commit 2b9bc1d

File tree

1 file changed: +107 -76 lines changed

examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py

Lines changed: 107 additions & 76 deletions
@@ -14,7 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Script to fine-tune Stable Diffusion for InstructPix2Pix."""
+"""
+Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
+Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
+"""
 
 import argparse
 import logging
@@ -30,6 +33,7 @@
 import PIL
 import requests
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
 import transformers
@@ -50,10 +54,13 @@
 from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version, deprecate, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+
+if is_wandb_available():
+    import wandb
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.32.0.dev0")
 
 logger = get_logger(__name__, log_level="INFO")
 
@@ -62,6 +69,48 @@
 }
 WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
 
+def log_validation(
+    pipeline,
+    args,
+    accelerator,
+    generator,
+):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    original_image = download_image(args.val_image_url)
+    edited_images = []
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
+    with autocast_ctx:
+        for _ in range(args.num_validation_images):
+            edited_images.append(
+                pipeline(
+                    args.validation_prompt,
+                    image=original_image,
+                    num_inference_steps=20,
+                    image_guidance_scale=1.5,
+                    guidance_scale=7,
+                    generator=generator,
+                ).images[0]
+            )
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "wandb":
+            wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
+            for edited_image in edited_images:
+                wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
+            tracker.log({"validation": wandb_table})
+
+    return edited_images
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
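Note: the new log_validation helper consolidates the inference-and-logging code that was previously duplicated in the training loop and after training (see the later hunks). Below is a minimal standalone sketch of calling it, under stated assumptions: `args` is faked with a SimpleNamespace whose field names mirror the script's CLI flags, the checkpoint is the public timbrooks/instruct-pix2pix, and the image URL is a placeholder to replace.

from types import SimpleNamespace

import torch
from accelerate import Accelerator
from diffusers import StableDiffusionInstructPix2PixPipeline

accelerator = Accelerator()
args = SimpleNamespace(
    validation_prompt="turn the sky into a sunset",  # illustrative prompt
    val_image_url="https://example.com/input.png",  # placeholder URL, supply your own
    num_validation_images=4,
)
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
)
generator = torch.Generator(device=accelerator.device).manual_seed(42)

# log_validation and its download_image helper come from the script in this diff.
edited_images = log_validation(pipeline, args, accelerator, generator)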
@@ -417,11 +466,6 @@ def main():
 
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
 
-    if args.report_to == "wandb":
-        if not is_wandb_available():
-            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb
-
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -467,6 +511,24 @@ def main():
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
     )
 
+    # InstructPix2Pix uses an additional image for conditioning. To accommodate that,
+    # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
+    # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
+    # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
+    # initialized to zero.
+    logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
+    in_channels = 8
+    out_channels = unet.conv_in.out_channels
+    unet.register_to_config(in_channels=in_channels)
+
+    with torch.no_grad():
+        new_conv_in = nn.Conv2d(
+            in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
+        )
+        new_conv_in.weight.zero_()
+        new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+        unet.conv_in = new_conv_in
+
     # Freeze vae, text_encoder and unet
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
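Note: copying the pretrained weights into the first 4 input channels (the copy above targets `[:, :4, :, :]`, since the pretrained conv_in only has 4 input channels) and zeroing the rest makes the widened layer a drop-in replacement: when the conditioning channels are zero, it reproduces the pretrained conv_in exactly, so fine-tuning starts from the original model's behaviour. A self-contained sketch with a toy conv (320 output channels mirrors SD's conv_in, but any size works):

import torch
import torch.nn as nn

old = nn.Conv2d(4, 320, kernel_size=3, padding=1)
new = nn.Conv2d(8, 320, kernel_size=3, padding=1)
with torch.no_grad():
    new.weight.zero_()
    new.weight[:, :4].copy_(old.weight)  # pretrained weights in the first 4 channels
    new.bias.copy_(old.bias)

x = torch.randn(1, 4, 64, 64)
x8 = torch.cat([x, torch.zeros(1, 4, 64, 64)], dim=1)  # conditioning channels all zero
assert torch.allclose(old(x), new(x8), atol=1e-6)  # identical outputs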
@@ -528,6 +590,11 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
 
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
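Note: the _orig_mod branch handles models wrapped by torch.compile, which returns an OptimizedModule holding the original module in its _orig_mod attribute. (is_compiled_module is not added to the imports in the hunks shown here; in diffusers it lives in diffusers.utils.torch_utils.) A minimal sketch of the wrapper behaviour:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)  # wraps the module in an OptimizedModule

# Saving/loading code needs the plain nn.Module, not the wrapper,
# which is why the script unwraps via _orig_mod.
assert compiled._orig_mod is model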
@@ -540,7 +607,8 @@ def save_model_hook(models, weights, output_dir):
                model.save_pretrained(os.path.join(output_dir, "unet"))
 
                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()
 
        def load_model_hook(models, input_dir):
            if args.use_ema:
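Note: the `if weights:` guard matters because some distributed setups (DeepSpeed, for example) invoke the save hook with an empty weights list, and an unconditional pop() would raise IndexError. A trivial illustration:

weights = []  # e.g. what the hook can receive under DeepSpeed
if weights:
    weights.pop()  # safe no-op when the list is empty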
@@ -730,17 +798,22 @@ def collate_fn(examples):
     )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
     )
 
     # Prepare everything with our `accelerator`.
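Note: a worked example of the scheduler bookkeeping above, under assumed values. accelerate steps the LR scheduler once per process for every optimizer step, so both the warmup and total step counts are scaled by num_processes:

import math

num_processes = 2
gradient_accumulation_steps = 4
num_train_epochs = 10
len_train_dataloader = 1000  # len(train_dataloader) before accelerator.prepare

len_after_sharding = math.ceil(len_train_dataloader / num_processes)  # 500 batches per process
updates_per_epoch = math.ceil(len_after_sharding / gradient_accumulation_steps)  # 125 optimizer steps
num_training_steps_for_scheduler = num_train_epochs * updates_per_epoch * num_processes
print(num_training_steps_for_scheduler)  # 2500 scheduler steps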
@@ -765,8 +838,14 @@ def collate_fn(examples):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
@@ -959,45 +1038,22 @@ def collate_fn(examples):
                 # The models need unwrapping because for compatibility in distributed training mode.
                 pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
-                    unet=accelerator.unwrap_model(unet),
-                    text_encoder=accelerator.unwrap_model(text_encoder),
-                    vae=accelerator.unwrap_model(vae),
+                    unet=unwrap_model(unet),
+                    text_encoder=unwrap_model(text_encoder),
+                    vae=unwrap_model(vae),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,
                 )
-                pipeline = pipeline.to(accelerator.device)
-                pipeline.set_progress_bar_config(disable=True)
 
                 # run inference
-                original_image = download_image(args.val_image_url)
-                edited_images = []
-                if torch.backends.mps.is_available():
-                    autocast_ctx = nullcontext()
-                else:
-                    autocast_ctx = torch.autocast(accelerator.device.type)
-
-                with autocast_ctx:
-                    for _ in range(args.num_validation_images):
-                        edited_images.append(
-                            pipeline(
-                                args.validation_prompt,
-                                image=original_image,
-                                num_inference_steps=20,
-                                image_guidance_scale=1.5,
-                                guidance_scale=7,
-                                generator=generator,
-                            ).images[0]
-                        )
-
-                for tracker in accelerator.trackers:
-                    if tracker.name == "wandb":
-                        wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
-                        for edited_image in edited_images:
-                            wandb_table.add_data(
-                                wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
-                            )
-                        tracker.log({"validation": wandb_table})
+                log_validation(
+                    pipeline,
+                    args,
+                    accelerator,
+                    generator,
+                )
+
                 if args.use_ema:
                     # Switch back to the original UNet parameters.
                     ema_unet.restore(unet.parameters())
@@ -1014,9 +1070,9 @@ def collate_fn(examples):
 
         pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
-            text_encoder=accelerator.unwrap_model(text_encoder),
-            vae=accelerator.unwrap_model(vae),
-            unet=unet,
+            text_encoder=unwrap_model(text_encoder),
+            vae=unwrap_model(vae),
+            unet=unwrap_model(unet),
             revision=args.revision,
             variant=args.variant,
         )
@@ -1031,31 +1087,6 @@ def collate_fn(examples):
             ignore_patterns=["step_*", "epoch_*"],
         )
 
-        if args.validation_prompt is not None:
-            edited_images = []
-            pipeline = pipeline.to(accelerator.device)
-            with torch.autocast(str(accelerator.device).replace(":0", "")):
-                for _ in range(args.num_validation_images):
-                    edited_images.append(
-                        pipeline(
-                            args.validation_prompt,
-                            image=original_image,
-                            num_inference_steps=20,
-                            image_guidance_scale=1.5,
-                            guidance_scale=7,
-                            generator=generator,
-                        ).images[0]
-                    )
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "wandb":
-                    wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
-                    for edited_image in edited_images:
-                        wandb_table.add_data(
-                            wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
-                        )
-                    tracker.log({"test": wandb_table})
-
     accelerator.end_training()
 
