Merged

Commits (showing changes from 9 of 33 commits):
298709e  remove custom scheduler (linoytsaban, Mar 18, 2025)
364f478  update requirements.txt (linoytsaban, Mar 18, 2025)
90e9517  log_validation with mixed precision (linoytsaban, Mar 18, 2025)
d90d7f0  Merge branch 'main' into flux_lora_advanced (linoytsaban, Mar 18, 2025)
bdd6cae  add intermediate embeddings saving when checkpointing is enabled (linoytsaban, Mar 18, 2025)
c8e165b  remove comment (linoytsaban, Mar 18, 2025)
d434db3  Merge remote-tracking branch 'origin/flux_lora_advanced' into flux_lo… (linoytsaban, Mar 18, 2025)
710fcae  fix validation (linoytsaban, Mar 18, 2025)
2d8ca60  Merge branch 'main' into flux_lora_advanced (linoytsaban, Mar 19, 2025)
0565932  add unwrap_model for accelerator, torch.no_grad context for validatio… (linoytsaban, Mar 19, 2025)
ba4dece  revert unwrap_model change temp (linoytsaban, Mar 19, 2025)
c155f22  add .module to address distributed training bug + replace accelerator… (linoytsaban, Mar 19, 2025)
9c4368d  changes to align advanced script with canonical script (linoytsaban, Mar 19, 2025)
7492e92  make changes for distributed training + unify unwrap_model calls in a… (linoytsaban, Mar 19, 2025)
0729c66  add module.dtype fix to dreambooth script (linoytsaban, Mar 19, 2025)
cc1d2ad  unify unwrap_model calls in dreambooth script (linoytsaban, Mar 19, 2025)
07c2974  Merge branch 'main' into flux_lora_advanced (linoytsaban, Mar 19, 2025)
603b57c  Merge branch 'main' into flux_lora_advanced (linoytsaban, Mar 20, 2025)
e5636f0  Merge branch 'main' into flux_lora_advanced (linoytsaban, Mar 27, 2025)
8bf49c7  fix condition in validation run (linoytsaban, Mar 27, 2025)
b211eea  Merge branch 'main' into flux_lora_advanced (linoytsaban, Apr 2, 2025)
9b2917f  mixed precision (linoytsaban, Apr 3, 2025)
22046d1  Merge branch 'main' into flux_lora_advanced (linoytsaban, Apr 3, 2025)
f1af7e2  Merge branch 'main' into flux_lora_advanced (linoytsaban, Apr 4, 2025)
8dc7005  Merge branch 'main' into flux_lora_advanced (linoytsaban, Apr 4, 2025)
d8ef75f  Merge branch 'main' into flux_lora_advanced (linoytsaban, Apr 7, 2025)
bfd8d45  Merge branch 'main' into flux_lora_advanced (linoytsaban, Apr 8, 2025)
5d249a7  Update examples/advanced_diffusion_training/train_dreambooth_lora_flu… (linoytsaban, Apr 8, 2025)
a4b1e7f  Merge branch 'main' into flux_lora_advanced (linoytsaban, Apr 8, 2025)
57ee3cf  smol style change (linoytsaban, Apr 8, 2025)
8b991a5  change autocast (linoytsaban, Apr 8, 2025)
bfd1df6  Apply style fixes (github-actions[bot], Apr 8, 2025)
c978ca3  Merge branch 'main' into flux_lora_advanced (linoytsaban, Apr 8, 2025)
7 changes: 4 additions & 3 deletions examples/advanced_diffusion_training/requirements.txt
@@ -1,7 +1,8 @@
-accelerate>=0.16.0
+accelerate>=0.31.0
 torchvision
-transformers>=4.25.1
+transformers>=4.41.2
 ftfy
 tensorboard
 Jinja2
-peft==0.7.0
+peft>=0.11.1
+sentencepiece
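These bumps presumably track what the updated scripts rely on (newer peft LoRA APIs, sentencepiece for Flux's T5 tokenizer) and can be installed with the usual `pip install -r examples/advanced_diffusion_training/requirements.txt`.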
examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

@@ -228,10 +228,20 @@ def log_validation(

     # run inference
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
-    autocast_ctx = nullcontext()
+    autocast_ctx = torch.autocast(accelerator.device.type)

Review comment (Member):
I think this is only needed for the intermediate validation. Do we need to check for that?

Reply (linoytsaban, Collaborator, Author, Apr 8, 2025):
Yeah I think you're right, tested it now and seems to work as expected, changed it now
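
For reference, the conditional form agreed on here matches the line that appears commented out in the canonical DreamBooth script further down. A minimal sketch, assuming `accelerator` and `is_final_validation` are in scope as elsewhere in `log_validation`:

```python
from contextlib import nullcontext

import torch

# Autocast only during intermediate validation; the final-validation pipeline
# is already loaded in the target dtype, so no autocast region is needed there.
autocast_ctx = (
    torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
)
```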


-    with autocast_ctx:
-        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+    # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
+    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+        pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
+    )
+    images = []
+    for _ in range(args.num_validation_images):
+        with autocast_ctx:
+            image = pipeline(
+                prompt_embeds=prompt_embeds,
+                pooled_prompt_embeds=pooled_prompt_embeds,
+                generator=generator).images[0]
+        images.append(image)

     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
@@ -1265,109 +1275,6 @@ def encode_prompt(

     return prompt_embeds, pooled_prompt_embeds, text_ids

-
-# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer:
-# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95
-class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        with torch.no_grad():
-            # create weights for timesteps
-            num_timesteps = 1000
-
-            # generate the multiplier based on cosmap loss weighing
-            # this is only used on linear timesteps for now
-
-            # cosine map weighing is higher in the middle and lower at the ends
-            # bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2
-            # cosmap_weighing = 2 / (math.pi * bot)
-
-            # sigma sqrt weighing is significantly higher at the end and lower at the beginning
-            sigma_sqrt_weighing = (self.sigmas**-2.0).float()
-            # clip at 1e4 (1e6 is too high)
-            sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
-            # bring to a mean of 1
-            sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()
-
-            # Create linear timesteps from 1000 to 0
-            timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu")
-
-            self.linear_timesteps = timesteps
-            # self.linear_timesteps_weights = cosmap_weighing
-            self.linear_timesteps_weights = sigma_sqrt_weighing
-
-            # self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu')
-            pass
-
-    def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
-        # Get the indices of the timesteps
-        step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
-
-        # Get the weights for the timesteps
-        weights = self.linear_timesteps_weights[step_indices].flatten()
-
-        return weights
-
-    def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
-        sigmas = self.sigmas.to(device=device, dtype=dtype)
-        schedule_timesteps = self.timesteps.to(device)
-        timesteps = timesteps.to(device)
-        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
-
-        sigma = sigmas[step_indices].flatten()
-        while len(sigma.shape) < n_dim:
-            sigma = sigma.unsqueeze(-1)
-
-        return sigma
-
-    def add_noise(
-        self,
-        original_samples: torch.Tensor,
-        noise: torch.Tensor,
-        timesteps: torch.Tensor,
-    ) -> torch.Tensor:
-        ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
-        ## Add noise according to flow matching.
-        ## zt = (1 - texp) * x + texp * z1
-
-        # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
-        # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
-
-        # timestep needs to be in [0, 1], we store them in [0, 1000]
-        # noisy_sample = (1 - timestep) * latent + timestep * noise
-        t_01 = (timesteps / 1000).to(original_samples.device)
-        noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
-
-        # n_dim = original_samples.ndim
-        # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
-        # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
-        return noisy_model_input
-
-    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
-        return sample
-
-    def set_train_timesteps(self, num_timesteps, device, linear=False):
-        if linear:
-            timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
-            self.timesteps = timesteps
-            return timesteps
-        else:
-            # distribute them closer to center. Inference distributes them as a bias toward first
-            # Generate values from 0 to 1
-            t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
-
-            # Scale and reverse the values to go from 1000 to 0
-            timesteps = (1 - t) * 1000
-
-            # Sort the timesteps in descending order
-            timesteps, _ = torch.sort(timesteps, descending=True)
-
-            self.timesteps = timesteps.to(device=device)
-
-            return timesteps


def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(

@@ -1499,7 +1406,7 @@ def main(args):
         )

     # Load scheduler and models
-    noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained(
+    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="scheduler"
     )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
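
With the custom scheduler removed, timestep sampling and loss weighting follow the canonical Flux script via helpers in `diffusers.training_utils`. A rough sketch of that pattern, assuming the training-loop names `bsz` and `model_input` and the script's own `get_sigmas` helper (visible in the hunk headers below) are in scope:

```python
from diffusers.training_utils import (
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
)

# Sample a density over timesteps (e.g. logit-normal), map it to scheduler
# indices, then weight the flow-matching loss per sample. This replaces the
# sigma-based weighing the custom scheduler used to bake in.
u = compute_density_for_timestep_sampling(
    weighting_scheme=args.weighting_scheme,
    batch_size=bsz,
    logit_mean=args.logit_mean,
    logit_std=args.logit_std,
    mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
weighting = compute_loss_weighting_for_sd3(
    weighting_scheme=args.weighting_scheme, sigmas=sigmas
)
```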
@@ -2288,16 +2195,25 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     if not freeze_text_encoder:
-                        if args.train_text_encoder:
+                        if args.train_text_encoder:  # text encoder tuning
                             params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters())
                         elif pure_textual_inversion:
-                            params_to_clip = itertools.chain(
-                                text_encoder_one.parameters(), text_encoder_two.parameters()
-                            )
+                            if args.enable_t5_ti:
+                                params_to_clip = itertools.chain(
+                                    text_encoder_one.parameters(), text_encoder_two.parameters()
+                                )
+                            else:
+                                params_to_clip = itertools.chain(text_encoder_one.parameters())
                         else:
-                            params_to_clip = itertools.chain(
-                                transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
-                            )
+                            if args.enable_t5_ti:
+                                params_to_clip = itertools.chain(
+                                    transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
+                                )
+                            else:
+                                params_to_clip = itertools.chain(
+                                    transformer.parameters(), text_encoder_one.parameters()
+                                )
                     else:
                         params_to_clip = itertools.chain(transformer.parameters())
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
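
The nested branching above reads as a small decision table: frozen encoders clip the transformer alone; `--train_text_encoder` adds the CLIP encoder; pure textual inversion clips only the encoders carrying new tokens (both when `--enable_t5_ti`, else CLIP alone); otherwise the transformer plus those same encoders. A hypothetical helper restating that logic (not part of the script, same flags assumed):

```python
import itertools

def params_to_clip_for(transformer, text_encoder_one, text_encoder_two, args,
                       freeze_text_encoder, pure_textual_inversion):
    if freeze_text_encoder:
        return transformer.parameters()
    if args.train_text_encoder:  # LoRA on the CLIP encoder alongside the transformer
        return itertools.chain(transformer.parameters(), text_encoder_one.parameters())
    encoders = (
        [text_encoder_one, text_encoder_two] if args.enable_t5_ti else [text_encoder_one]
    )
    encoder_params = itertools.chain(*(e.parameters() for e in encoders))
    if pure_textual_inversion:  # only the new token embeddings are trained
        return encoder_params
    return itertools.chain(transformer.parameters(), encoder_params)
```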
@@ -2339,6 +2255,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
+                        if args.train_text_encoder_ti:
+                            embedding_handler.save_embeddings(
+                                f"{args.output_dir}/{Path(args.output_dir).name}_emb_checkpoint_{global_step}.safetensors")
                         logger.info(f"Saved state to {save_path}")

             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
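
The intermediate embedding checkpoints saved above are ordinary safetensors files; a minimal sketch for inspecting one (the path is illustrative, and the key names depend on how the embedding handler labels the encoders):

```python
from safetensors.torch import load_file

# e.g. --output_dir lora-out, saved at global_step 500
state_dict = load_file("lora-out/lora-out_emb_checkpoint_500.safetensors")
for key, tensor in state_dict.items():
    print(key, tuple(tensor.shape), tensor.dtype)
```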
@@ -2351,8 +2270,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         if accelerator.is_main_process:
             if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                 # create pipeline
-                if freeze_text_encoder:
+                if freeze_text_encoder:  # no text encoder one, two optimizations
                     text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+                    text_encoder_one.to(weight_dtype)
+                    text_encoder_two.to(weight_dtype)
+
                 pipeline = FluxPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
@@ -2378,9 +2300,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if freeze_text_encoder:
                     del text_encoder_one, text_encoder_two
                     free_memory()
-                elif args.train_text_encoder:
-                    del text_encoder_two
-                    free_memory()

     # Save the lora layers
     accelerator.wait_for_everyone()
22 changes: 15 additions & 7 deletions examples/dreambooth/train_dreambooth_lora_flux.py
@@ -177,16 +177,25 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
autocast_ctx = torch.autocast(accelerator.device.type)

with autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
)
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
generator=generator).images[0]
images.append(image)

for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
@@ -203,8 +212,7 @@ def log_validation(
         )

     del pipeline
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    free_memory()

     return images
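
`free_memory` comes from `diffusers.training_utils` and generalizes the CUDA-only cleanup it replaces here. Roughly, as a sketch rather than the exact implementation, it garbage-collects and then empties whichever accelerator cache is available:

```python
import gc

import torch

def free_memory_sketch():
    # Drop dangling references first, then release cached device memory.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
```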
