Skip to content

Commit f11b922

Browse files
Modularize Dreambooth LoRA SD inferencing during and after training (#6654)
* modularize log validation * run make style and refactor wandb support * remove redundant initialization --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 3dd4168 commit f11b922

File tree

1 file changed

+84
-89
lines changed

1 file changed

+84
-89
lines changed

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 84 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@
6666
from diffusers.utils.torch_utils import is_compiled_module
6767

6868

69+
if is_wandb_available():
70+
import wandb
71+
6972
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
7073
check_min_version("0.27.0.dev0")
7174

@@ -113,6 +116,71 @@ def save_model_card(
113116
model_card.save(os.path.join(repo_folder, "README.md"))
114117

115118

119+
def log_validation(
    pipeline,
    args,
    accelerator,
    pipeline_args,
    epoch,
    is_final_validation=False,
):
    """Run validation inference with ``pipeline`` and log the images to all trackers.

    Args:
        pipeline: A loaded diffusers pipeline (moved to ``accelerator.device`` here).
        args: Parsed training arguments; reads ``num_validation_images``,
            ``validation_prompt``, ``validation_images``, and ``seed``.
        accelerator: The ``accelerate.Accelerator`` providing ``device`` and ``trackers``.
        pipeline_args (dict): Keyword arguments forwarded to the pipeline call
            (e.g. ``{"prompt": ...}`` or precomputed prompt embeddings).
        epoch (int): Global step/epoch used as the logging step for trackers.
        is_final_validation (bool): When True, images are logged under the
            "test" phase instead of "validation".

    Returns:
        list: The generated PIL images.
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )
    # We train on the simplified learning objective. If we were previously predicting a variance,
    # we need the scheduler to ignore it.
    scheduler_args = {}

    if "variance_type" in pipeline.scheduler.config:
        variance_type = pipeline.scheduler.config.variance_type

        if variance_type in ["learned", "learned_range"]:
            variance_type = "fixed_small"

        scheduler_args["variance_type"] = variance_type

    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)

    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # Run inference. Compare against None explicitly so that an explicit `--seed 0`
    # is honored (plain truthiness would silently drop it).
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None

    images = []
    if args.validation_images is None:
        # Text-to-image validation: sample `num_validation_images` from the prompt alone.
        for _ in range(args.num_validation_images):
            with torch.cuda.amp.autocast():
                image = pipeline(**pipeline_args, generator=generator).images[0]
            images.append(image)
    else:
        # Image-conditioned validation: run once per user-supplied conditioning image.
        for image_path in args.validation_images:
            image = Image.open(image_path)
            with torch.cuda.amp.autocast():
                image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
            images.append(image)

    phase_name = "test" if is_final_validation else "validation"
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
        if tracker.name == "wandb":
            tracker.log(
                {
                    phase_name: [
                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
                    ]
                }
            )

    # Free the pipeline before returning so validation doesn't hold GPU memory
    # across training steps; guard empty_cache for CPU-only environments.
    del pipeline
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return images
182+
183+
116184
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
117185
text_encoder_config = PretrainedConfig.from_pretrained(
118186
pretrained_model_name_or_path,
@@ -684,7 +752,6 @@ def main(args):
684752
if args.report_to == "wandb":
685753
if not is_wandb_available():
686754
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
687-
import wandb
688755

689756
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
690757
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
@@ -1265,10 +1332,6 @@ def compute_text_embeddings(prompt):
12651332

12661333
if accelerator.is_main_process:
12671334
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
1268-
logger.info(
1269-
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1270-
f" {args.validation_prompt}."
1271-
)
12721335
# create pipeline
12731336
pipeline = DiffusionPipeline.from_pretrained(
12741337
args.pretrained_model_name_or_path,
@@ -1279,26 +1342,6 @@ def compute_text_embeddings(prompt):
12791342
torch_dtype=weight_dtype,
12801343
)
12811344

1282-
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1283-
scheduler_args = {}
1284-
1285-
if "variance_type" in pipeline.scheduler.config:
1286-
variance_type = pipeline.scheduler.config.variance_type
1287-
1288-
if variance_type in ["learned", "learned_range"]:
1289-
variance_type = "fixed_small"
1290-
1291-
scheduler_args["variance_type"] = variance_type
1292-
1293-
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
1294-
pipeline.scheduler.config, **scheduler_args
1295-
)
1296-
1297-
pipeline = pipeline.to(accelerator.device)
1298-
pipeline.set_progress_bar_config(disable=True)
1299-
1300-
# run inference
1301-
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
13021345
if args.pre_compute_text_embeddings:
13031346
pipeline_args = {
13041347
"prompt_embeds": validation_prompt_encoder_hidden_states,
@@ -1307,36 +1350,13 @@ def compute_text_embeddings(prompt):
13071350
else:
13081351
pipeline_args = {"prompt": args.validation_prompt}
13091352

1310-
if args.validation_images is None:
1311-
images = []
1312-
for _ in range(args.num_validation_images):
1313-
with torch.cuda.amp.autocast():
1314-
image = pipeline(**pipeline_args, generator=generator).images[0]
1315-
images.append(image)
1316-
else:
1317-
images = []
1318-
for image in args.validation_images:
1319-
image = Image.open(image)
1320-
with torch.cuda.amp.autocast():
1321-
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
1322-
images.append(image)
1323-
1324-
for tracker in accelerator.trackers:
1325-
if tracker.name == "tensorboard":
1326-
np_images = np.stack([np.asarray(img) for img in images])
1327-
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
1328-
if tracker.name == "wandb":
1329-
tracker.log(
1330-
{
1331-
"validation": [
1332-
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1333-
for i, image in enumerate(images)
1334-
]
1335-
}
1336-
)
1337-
1338-
del pipeline
1339-
torch.cuda.empty_cache()
1353+
images = log_validation(
1354+
pipeline,
1355+
args,
1356+
accelerator,
1357+
pipeline_args,
1358+
epoch,
1359+
)
13401360

13411361
# Save the lora layers
13421362
accelerator.wait_for_everyone()
@@ -1364,46 +1384,21 @@ def compute_text_embeddings(prompt):
13641384
args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
13651385
)
13661386

1367-
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1368-
scheduler_args = {}
1369-
1370-
if "variance_type" in pipeline.scheduler.config:
1371-
variance_type = pipeline.scheduler.config.variance_type
1372-
1373-
if variance_type in ["learned", "learned_range"]:
1374-
variance_type = "fixed_small"
1375-
1376-
scheduler_args["variance_type"] = variance_type
1377-
1378-
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
1379-
1380-
pipeline = pipeline.to(accelerator.device)
1381-
13821387
# load attention processors
13831388
pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
13841389

13851390
# run inference
13861391
images = []
13871392
if args.validation_prompt and args.num_validation_images > 0:
1388-
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
1389-
images = [
1390-
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
1391-
for _ in range(args.num_validation_images)
1392-
]
1393-
1394-
for tracker in accelerator.trackers:
1395-
if tracker.name == "tensorboard":
1396-
np_images = np.stack([np.asarray(img) for img in images])
1397-
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
1398-
if tracker.name == "wandb":
1399-
tracker.log(
1400-
{
1401-
"test": [
1402-
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1403-
for i, image in enumerate(images)
1404-
]
1405-
}
1406-
)
1393+
pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
1394+
images = log_validation(
1395+
pipeline,
1396+
args,
1397+
accelerator,
1398+
pipeline_args,
1399+
epoch,
1400+
is_final_validation=True,
1401+
)
14071402

14081403
if args.push_to_hub:
14091404
save_model_card(

0 commit comments

Comments
 (0)