
Commit 3bc3b48

satani99 and sayakpaul authored
Modularize train_text_to_image_lora SD inferencing during and after training in example (#8283)
* Modularized the train_lora file
* Modularized the train_lora file
* Modularized the train_lora file
* Modularized the train_lora file
* Modularized the train_lora file

Co-authored-by: Sayak Paul <[email protected]>
1 parent 581d8aa commit 3bc3b48

File tree

1 file changed (+61, -85 lines)


examples/text_to_image/train_text_to_image_lora.py

Lines changed: 61 additions & 85 deletions
@@ -52,6 +52,9 @@
 from diffusers.utils.torch_utils import is_compiled_module


+if is_wandb_available():
+    import wandb
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.29.0.dev0")

@@ -99,6 +102,48 @@ def save_model_card(
     model_card.save(os.path.join(repo_folder, "README.md"))


+def log_validation(
+    pipeline,
+    args,
+    accelerator,
+    epoch,
+    is_final_validation=False,
+):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+    generator = torch.Generator(device=accelerator.device)
+    if args.seed is not None:
+        generator = generator.manual_seed(args.seed)
+    images = []
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
+    with autocast_ctx:
+        for _ in range(args.num_validation_images):
+            images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
+
+    for tracker in accelerator.trackers:
+        phase_name = "test" if is_final_validation else "validation"
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    phase_name: [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+                    ]
+                }
+            )
+    return images
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
@@ -414,11 +459,6 @@ def main():
     if torch.backends.mps.is_available():
         accelerator.native_amp = False

-    if args.report_to == "wandb":
-        if not is_wandb_available():
-            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb
-
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -864,10 +904,6 @@ def collate_fn(examples):

         if accelerator.is_main_process:
             if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
-                logger.info(
-                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                    f" {args.validation_prompt}."
-                )
                 # create pipeline
                 pipeline = DiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
@@ -876,38 +912,7 @@ def collate_fn(examples):
                     variant=args.variant,
                     torch_dtype=weight_dtype,
                 )
-                pipeline = pipeline.to(accelerator.device)
-                pipeline.set_progress_bar_config(disable=True)
-
-                # run inference
-                generator = torch.Generator(device=accelerator.device)
-                if args.seed is not None:
-                    generator = generator.manual_seed(args.seed)
-                images = []
-                if torch.backends.mps.is_available():
-                    autocast_ctx = nullcontext()
-                else:
-                    autocast_ctx = torch.autocast(accelerator.device.type)
-
-                with autocast_ctx:
-                    for _ in range(args.num_validation_images):
-                        images.append(
-                            pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
-                        )
-
-                for tracker in accelerator.trackers:
-                    if tracker.name == "tensorboard":
-                        np_images = np.stack([np.asarray(img) for img in images])
-                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
-                    if tracker.name == "wandb":
-                        tracker.log(
-                            {
-                                "validation": [
-                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                    for i, image in enumerate(images)
-                                ]
-                            }
-                        )
+                images = log_validation(pipeline, args, accelerator, epoch)

                 del pipeline
                 torch.cuda.empty_cache()
@@ -925,21 +930,6 @@ def collate_fn(examples):
             safe_serialization=True,
         )

-        if args.push_to_hub:
-            save_model_card(
-                repo_id,
-                images=images,
-                base_model=args.pretrained_model_name_or_path,
-                dataset_name=args.dataset_name,
-                repo_folder=args.output_dir,
-            )
-            upload_folder(
-                repo_id=repo_id,
-                folder_path=args.output_dir,
-                commit_message="End of training",
-                ignore_patterns=["step_*", "epoch_*"],
-            )
-
         # Final inference
         # Load previous pipeline
         if args.validation_prompt is not None:
@@ -949,41 +939,27 @@ def collate_fn(examples):
                 variant=args.variant,
                 torch_dtype=weight_dtype,
             )
-            pipeline = pipeline.to(accelerator.device)

             # load attention processors
             pipeline.load_lora_weights(args.output_dir)

             # run inference
-            generator = torch.Generator(device=accelerator.device)
-            if args.seed is not None:
-                generator = generator.manual_seed(args.seed)
-            images = []
-            if torch.backends.mps.is_available():
-                autocast_ctx = nullcontext()
-            else:
-                autocast_ctx = torch.autocast(accelerator.device.type)
-
-            with autocast_ctx:
-                for _ in range(args.num_validation_images):
-                    images.append(
-                        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
-                    )
+            images = log_validation(pipeline, args, accelerator, epoch, is_final_validation=True)

-            for tracker in accelerator.trackers:
-                if len(images) != 0:
-                    if tracker.name == "tensorboard":
-                        np_images = np.stack([np.asarray(img) for img in images])
-                        tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
-                    if tracker.name == "wandb":
-                        tracker.log(
-                            {
-                                "test": [
-                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                    for i, image in enumerate(images)
-                                ]
-                            }
-                        )
+        if args.push_to_hub:
+            save_model_card(
+                repo_id,
+                images=images,
+                base_model=args.pretrained_model_name_or_path,
+                dataset_name=args.dataset_name,
+                repo_folder=args.output_dir,
+            )
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )

     accelerator.end_training()

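Usage sketch (illustrative, not part of the commit): the fragment below mirrors the two call sites of the new log_validation helper shown in the diff. It assumes the objects that exist inside main() of train_text_to_image_lora.py (the parsed args namespace, the accelerate Accelerator, weight_dtype, and the epoch counter) and abbreviates the DiffusionPipeline.from_pretrained call to the arguments visible in the diff, so it is not a standalone script.

from diffusers import DiffusionPipeline

# Build the inference pipeline the same way the training script does (abbreviated).
pipeline = DiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    variant=args.variant,
    torch_dtype=weight_dtype,
)

# During training (every args.validation_epochs epochs): images are generated from
# args.validation_prompt and logged under the "validation" phase name.
images = log_validation(pipeline, args, accelerator, epoch)

# After training, once the LoRA weights saved to args.output_dir are loaded back,
# is_final_validation=True switches the logged phase name to "test".
pipeline.load_lora_weights(args.output_dir)
images = log_validation(pipeline, args, accelerator, epoch, is_final_validation=True)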