
Commit 83062fb

[Advanced DreamBooth LoRA SDXL] Support EDM-style training (follow up of #7126) (#7182)
* add edm style training
* style
* finish adding edm training feature
* import fix
* fix latents mean
* minor adjustments
* add edm to readme
* style
* fix autocast and scheduler config issues when using edm
* style

Co-authored-by: Sayak Paul <[email protected]>
1 parent b6d7e31 commit 83062fb

File tree: 2 files changed (+208 / -31 lines)

examples/advanced_diffusion_training/README.md

Lines changed: 44 additions & 0 deletions
````diff
@@ -259,6 +259,50 @@ pip install git+https://github.com/huggingface/peft.git
 **Inference**
 The inference is the same as if you train a regular LoRA 🤗
 
+## Conducting EDM-style training
+
+It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364).
+
+To enable it, simply set:
+
+```diff
++ --do_edm_style_training \
+```
+
+Other SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. Below is an example command:
+
+```bash
+accelerate launch train_dreambooth_lora_sdxl_advanced.py \
+  --pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic" \
+  --dataset_name="linoyts/3d_icon" \
+  --instance_prompt="3d icon in the style of TOK" \
+  --validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
+  --output_dir="3d-icon-SDXL-LoRA" \
+  --do_edm_style_training \
+  --caption_column="prompt" \
+  --mixed_precision="bf16" \
+  --resolution=1024 \
+  --train_batch_size=3 \
+  --repeats=1 \
+  --report_to="wandb" \
+  --gradient_accumulation_steps=1 \
+  --gradient_checkpointing \
+  --learning_rate=1.0 \
+  --text_encoder_lr=1.0 \
+  --optimizer="prodigy" \
+  --train_text_encoder_ti \
+  --train_text_encoder_ti_frac=0.5 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --rank=8 \
+  --max_train_steps=1000 \
+  --checkpointing_steps=2000 \
+  --seed="0" \
+  --push_to_hub
+```
+
+> [!CAUTION]
+> Min-SNR gamma is not supported with EDM-style training yet. When training with the PlaygroundAI model, it's recommended not to pass a `--variant`.
 
 ### Tips and Tricks
 Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)
````
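For reference, inference with the resulting LoRA follows the usual flow. Below is a minimal, hypothetical sketch (paths and prompt are placeholders; since the command above uses `--train_text_encoder_ti`, the learned `TOK` embeddings must also be loaded, as covered in the pivotal tuning section of this README):

```python
# A minimal sketch, assuming the output_dir / Hub repo name from the command above.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "playgroundai/playground-v2.5-1024px-aesthetic", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("3d-icon-SDXL-LoRA")  # or the pushed Hub repo id
image = pipe(prompt="a TOK icon of an astronaut riding a horse").images[0]
image.save("astronaut_icon.png")
```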

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 164 additions & 31 deletions
```diff
@@ -14,9 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import contextlib
 import gc
 import hashlib
 import itertools
+import json
 import logging
 import math
 import os
```
```diff
@@ -37,7 +39,7 @@
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
-from huggingface_hub import create_repo, upload_folder
+from huggingface_hub import create_repo, hf_hub_download, upload_folder
 from packaging import version
 from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
```
```diff
@@ -55,6 +57,8 @@
     AutoencoderKL,
     DDPMScheduler,
     DPMSolverMultistepScheduler,
+    EDMEulerScheduler,
+    EulerDiscreteScheduler,
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
 )
```
```diff
@@ -79,6 +83,20 @@
 logger = get_logger(__name__)
 
 
+def determine_scheduler_type(pretrained_model_name_or_path, revision):
+    model_index_filename = "model_index.json"
+    if os.path.isdir(pretrained_model_name_or_path):
+        model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)
+    else:
+        model_index = hf_hub_download(
+            repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision
+        )
+
+    with open(model_index, "r") as f:
+        scheduler_type = json.load(f)["scheduler"][1]
+    return scheduler_type
+
+
 def save_model_card(
     repo_id: str,
     use_dora: bool,
```
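To make the helper concrete: the `"scheduler"` entry of `model_index.json` is a `["library", "class_name"]` pair, so indexing `[1]` yields the scheduler class name. A hypothetical usage sketch (the printed name depends on the checkpoint):

```python
# Hypothetical usage of the helper above; revision=None loads the default branch.
scheduler_type = determine_scheduler_type(
    "playgroundai/playground-v2.5-1024px-aesthetic", revision=None
)
print(scheduler_type)  # an EDM scheduler class name for this checkpoint
if "EDM" in scheduler_type:
    print("EDM-style training will be enabled automatically")
```

This substring check is what later forces `args.do_edm_style_training = True` for EDM checkpoints, regardless of the CLI flag.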
```diff
@@ -370,6 +388,11 @@ def parse_args(input_args=None):
             " `args.validation_prompt` multiple times: `args.num_validation_images`."
         ),
     )
+    parser.add_argument(
+        "--do_edm_style_training",
+        action="store_true",
+        help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
+    )
     parser.add_argument(
         "--with_prior_preservation",
         default=False,
```
```diff
@@ -1117,6 +1140,8 @@ def main(args):
             "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
             " Please use `huggingface-cli login` to authenticate with the Hub."
         )
+    if args.do_edm_style_training and args.snr_gamma is not None:
+        raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
 
     logging_dir = Path(args.output_dir, args.logging_dir)
```

```diff
@@ -1234,7 +1259,19 @@
     )
 
     # Load scheduler and models
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)
+    if "EDM" in scheduler_type:
+        args.do_edm_style_training = True
+        noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+        logger.info("Performing EDM-style training!")
+    elif args.do_edm_style_training:
+        noise_scheduler = EulerDiscreteScheduler.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="scheduler"
+        )
+        logger.info("Performing EDM-style training!")
+    else:
+        noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
     text_encoder_one = text_encoder_cls_one.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )
```
```diff
@@ -1252,7 +1289,12 @@
         revision=args.revision,
         variant=args.variant,
     )
-    vae_scaling_factor = vae.config.scaling_factor
+    latents_mean = latents_std = None
+    if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
+        latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
+    if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
+        latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
 
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )
```
```diff
@@ -1790,6 +1832,19 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         disable=not accelerator.is_local_main_process,
     )
 
+    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+        # TODO: revisit other sampling algorithms
+        sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
+        schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
+        timesteps = timesteps.to(accelerator.device)
+
+        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+        sigma = sigmas[step_indices].flatten()
+        while len(sigma.shape) < n_dim:
+            sigma = sigma.unsqueeze(-1)
+        return sigma
+
     if args.train_text_encoder:
         num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
     elif args.train_text_encoder_ti:  # args.train_text_encoder_ti
```
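As an illustration (not part of the commit), `get_sigmas` maps sampled timesteps back to their noise levels and reshapes them for broadcasting. A sketch, assuming `noise_scheduler` exposes `sigmas` and `timesteps` the way the Euler-style schedulers above do:

```python
# Illustrative only: look up sigmas for 4 sampled timesteps and reshape them
# to (bsz, 1, 1, 1) so they broadcast against 4D latent tensors.
bsz = 4
indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
timesteps = noise_scheduler.timesteps[indices]
sigmas = get_sigmas(timesteps, n_dim=4)
assert sigmas.shape == (bsz, 1, 1, 1)
```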
```diff
@@ -1841,9 +1896,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                 model_input = vae.encode(pixel_values).latent_dist.sample()
 
-                model_input = model_input * vae_scaling_factor
-                if args.pretrained_vae_model_name_or_path is None:
-                    model_input = model_input.to(weight_dtype)
+                if latents_mean is None and latents_std is None:
+                    model_input = model_input * vae.config.scaling_factor
+                    if args.pretrained_vae_model_name_or_path is None:
+                        model_input = model_input.to(weight_dtype)
+                else:
+                    latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
+                    latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
+                    model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
+                    model_input = model_input.to(dtype=weight_dtype)
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(model_input)
```
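The new `else` branch standardizes the latents when the VAE config carries per-channel statistics (as Playground v2.5's does). A standalone sketch with hypothetical helper names, including the inverse that would be applied before `vae.decode` at inference:

```python
# Hypothetical helpers mirroring the normalization above.
def normalize_latents(latents, latents_mean, latents_std, scaling_factor):
    # z_norm = (z - mean) * scaling_factor / std
    return (latents - latents_mean) * scaling_factor / latents_std


def denormalize_latents(latents, latents_mean, latents_std, scaling_factor):
    # inverse: z = z_norm * std / scaling_factor + mean
    return latents * latents_std / scaling_factor + latents_mean
```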
```diff
@@ -1854,15 +1915,32 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 )
 
                 bsz = model_input.shape[0]
+
                 # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
-                )
-                timesteps = timesteps.long()
+                if not args.do_edm_style_training:
+                    timesteps = torch.randint(
+                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+                    )
+                    timesteps = timesteps.long()
+                else:
+                    # in EDM formulation, the model is conditioned on the pre-conditioned noise levels
+                    # instead of discrete timesteps, so here we sample indices to get the noise levels
+                    # from `scheduler.timesteps`
+                    indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
+                    timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)
 
                 # Add noise to the model input according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
                 noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+                # For EDM-style training, we first obtain the sigmas based on the continuous timesteps.
+                # We then precondition the final model inputs based on these sigmas instead of the timesteps.
+                # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+                if args.do_edm_style_training:
+                    sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)
+                    if "EDM" in scheduler_type:
+                        inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)
+                    else:
+                        inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)
 
                 # time ids
                 add_time_ids = torch.cat(
```
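The fallback branch is the EDM input preconditioning from Table 1 of the paper, `c_in(sigma) = 1 / sqrt(sigma^2 + sigma_data^2)`, with `sigma_data` assumed to be 1; `EDMEulerScheduler.precondition_inputs` applies the same scaling using the scheduler's configured `sigma_data`. A sketch of the fallback with that assumption spelled out:

```python
# Sketch of the non-EDM-scheduler branch above: scale noisy latents by
# c_in(sigma) = 1 / sqrt(sigma**2 + sigma_data**2), assuming sigma_data = 1.
def precondition_inputs_fallback(noisy_model_input, sigmas, sigma_data=1.0):
    c_in = 1.0 / (sigmas**2 + sigma_data**2) ** 0.5
    return noisy_model_input * c_in
```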
```diff
@@ -1888,7 +1966,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                     }
                     prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
                     model_pred = unet(
-                        noisy_model_input,
+                        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
                         timesteps,
                         prompt_embeds_input,
                         added_cond_kwargs=unet_added_conditions,
```
```diff
@@ -1906,14 +1984,42 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                     )
                     prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
                     model_pred = unet(
-                        noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
+                        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
+                        timesteps,
+                        prompt_embeds_input,
+                        added_cond_kwargs=unet_added_conditions,
                     ).sample
 
+                weighting = None
+                if args.do_edm_style_training:
+                    # Similar to the input preconditioning, the model predictions are also preconditioned
+                    # on noised model inputs (before preconditioning) and the sigmas.
+                    # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+                    if "EDM" in scheduler_type:
+                        model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)
+                    else:
+                        if noise_scheduler.config.prediction_type == "epsilon":
+                            model_pred = model_pred * (-sigmas) + noisy_model_input
+                        elif noise_scheduler.config.prediction_type == "v_prediction":
+                            model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (
+                                noisy_model_input / (sigmas**2 + 1)
+                            )
+                    # We are not doing weighting here because it tends to result in numerical problems.
+                    # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
+                    # There might be other alternatives for weighting as well:
+                    # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686
+                    if "EDM" not in scheduler_type:
+                        weighting = (sigmas**-2.0).float()
+
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
-                    target = noise
+                    target = model_input if args.do_edm_style_training else noise
                 elif noise_scheduler.config.prediction_type == "v_prediction":
-                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+                    target = (
+                        model_input
+                        if args.do_edm_style_training
+                        else noise_scheduler.get_velocity(model_input, noise, timesteps)
+                    )
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
```
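Both output-preconditioning branches turn the raw network output into an estimate of the clean latents, which is why the loss target switches to `model_input` here. The fallback branch corresponds to the EulerDiscrete convention `x_t = x_0 + sigma * noise`; a hypothetical restatement:

```python
# Hypothetical restatement of the fallback branch, assuming x_t = x_0 + sigma * noise:
#   epsilon-prediction: x0_hat = x_t - sigma * eps_hat
#   v-prediction:       x0_hat = x_t / (sigma**2 + 1) - v_hat * sigma / sqrt(sigma**2 + 1)
def predicted_original_sample(model_pred, noisy_model_input, sigmas, prediction_type):
    if prediction_type == "epsilon":
        return noisy_model_input - sigmas * model_pred
    if prediction_type == "v_prediction":
        return noisy_model_input / (sigmas**2 + 1) - model_pred * sigmas / (sigmas**2 + 1) ** 0.5
    raise ValueError(f"Unknown prediction type {prediction_type}")
```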

```diff
@@ -1923,10 +2029,28 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         target, target_prior = torch.chunk(target, 2, dim=0)
 
                     # Compute prior loss
-                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+                    if weighting is not None:
+                        prior_loss = torch.mean(
+                            (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+                                target_prior.shape[0], -1
+                            ),
+                            1,
+                        )
+                        prior_loss = prior_loss.mean()
+                    else:
+                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
 
                 if args.snr_gamma is None:
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                    if weighting is not None:
+                        loss = torch.mean(
+                            (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(
+                                target.shape[0], -1
+                            ),
+                            1,
+                        )
+                        loss = loss.mean()
+                    else:
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                 else:
                     # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
```
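The weighted branches compute a per-sample mean of `sigma**-2`-weighted squared errors and then average over the batch. An equivalent standalone sketch with a hypothetical helper name:

```python
# Hypothetical helper equivalent to the weighted branches above.
def weighted_mse(model_pred, target, weighting):
    per_elem = weighting.float() * (model_pred.float() - target.float()) ** 2
    per_sample = per_elem.reshape(target.shape[0], -1).mean(dim=1)
    return per_sample.mean()
```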
```diff
@@ -2049,26 +2173,32 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
                 scheduler_args = {}
 
-                if "variance_type" in pipeline.scheduler.config:
-                    variance_type = pipeline.scheduler.config.variance_type
+                if not args.do_edm_style_training:
+                    if "variance_type" in pipeline.scheduler.config:
+                        variance_type = pipeline.scheduler.config.variance_type
 
-                    if variance_type in ["learned", "learned_range"]:
-                        variance_type = "fixed_small"
+                        if variance_type in ["learned", "learned_range"]:
+                            variance_type = "fixed_small"
 
-                    scheduler_args["variance_type"] = variance_type
+                        scheduler_args["variance_type"] = variance_type
 
-                pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
-                    pipeline.scheduler.config, **scheduler_args
-                )
+                    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+                        pipeline.scheduler.config, **scheduler_args
+                    )
 
                 pipeline = pipeline.to(accelerator.device)
                 pipeline.set_progress_bar_config(disable=True)
 
                 # run inference
                 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
                 pipeline_args = {"prompt": args.validation_prompt}
+                inference_ctx = (
+                    contextlib.nullcontext()
+                    if "playground" in args.pretrained_model_name_or_path
+                    else torch.cuda.amp.autocast()
+                )
 
-                with torch.cuda.amp.autocast():
+                with inference_ctx:
                     images = [
                         pipeline(**pipeline_args, generator=generator).images[0]
                         for _ in range(args.num_validation_images)
```
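For clarity, the context-manager switch in isolation (hypothetical helper): the `"playground"` substring check is the commit's own heuristic for skipping fp16 autocast on those checkpoints, per the "fix autocast and scheduler config issues when using edm" item in the commit message:

```python
# Sketch of the validation-context selection above.
import contextlib

import torch


def make_inference_ctx(pretrained_model_name_or_path: str):
    if "playground" in pretrained_model_name_or_path:
        # run validation in the pipeline's native dtype instead of fp16 autocast
        return contextlib.nullcontext()
    return torch.cuda.amp.autocast()
```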
```diff
@@ -2144,15 +2274,18 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
         scheduler_args = {}
 
-        if "variance_type" in pipeline.scheduler.config:
-            variance_type = pipeline.scheduler.config.variance_type
+        if not args.do_edm_style_training:
+            if "variance_type" in pipeline.scheduler.config:
+                variance_type = pipeline.scheduler.config.variance_type
 
-            if variance_type in ["learned", "learned_range"]:
-                variance_type = "fixed_small"
+                if variance_type in ["learned", "learned_range"]:
+                    variance_type = "fixed_small"
 
-            scheduler_args["variance_type"] = variance_type
+                scheduler_args["variance_type"] = variance_type
 
-        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+            pipeline.scheduler.config, **scheduler_args
+        )
 
         # load attention processors
         pipeline.load_lora_weights(args.output_dir)
```
