diff --git a/examples/research_projects/flux_lora_quantization/README.md b/examples/research_projects/flux_lora_quantization/README.md new file mode 100644 index 000000000000..ffec85550e51 --- /dev/null +++ b/examples/research_projects/flux_lora_quantization/README.md @@ -0,0 +1,166 @@ +## LoRA fine-tuning Flux.1 Dev with quantization + +> [!NOTE] +> This example is educational in nature and fixes some arguments to keep things simple. It should act as a reference to build things further. + +This example shows how to fine-tune [Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with LoRA and quantization. We show this by using the [`Norod78/Yarn-art-style`](https://huggingface.co/datasets/Norod78/Yarn-art-style) dataset. Steps below summarize the workflow: + +* We precompute the text embeddings in `compute_embeddings.py` and serialize them into a parquet file. +* `train_dreambooth_lora_flux_miniature.py` takes care of training: + * Since we already precomputed the text embeddings, we don't load the text encoders. + * We load the VAE and use it to precompute the image latents and we then delete it. + * Load the Flux transformer, quantize it with the [NF4 datatype](https://arxiv.org/abs/2305.14314) through `bitsandbytes`, prepare it for 4bit training. + * Add LoRA adapter layers to it and then ensure they are kept in FP32 precision. + * Train! + +To run training in a memory-optimized manner, we additionally use: + +* 8Bit Adam +* Gradient checkpointing + +We have tested the scripts on a 24GB 4090. It works on a free-tier Colab Notebook, too, but it's extremely slow. + +## Training + +Ensure you have installed the required libraries: + +```bash +pip install -U transformers accelerate bitsandbytes peft datasets +pip install git+https://github.com/huggingface/diffusers -U +``` + +Now, compute the text embeddings: + +```bash +python compute_embeddings.py +``` + +It should create a file named `embeddings.parquet`. We're then ready to launch training. First, authenticate so that you can access the Flux.1 Dev model: + +```bash +huggingface-cli +``` + +Then launch: + +```bash +accelerate launch --config_file=accelerate.yaml \ + train_dreambooth_lora_flux_miniature.py \ + --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \ + --data_df_path="embeddings.parquet" \ + --output_dir="yarn_art_lora_flux_nf4" \ + --mixed_precision="fp16" \ + --use_8bit_adam \ + --weighting_scheme="none" \ + --resolution=1024 \ + --train_batch_size=1 \ + --repeats=1 \ + --learning_rate=1e-4 \ + --guidance_scale=1 \ + --report_to="wandb" \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --cache_latents \ + --rank=4 \ + --max_train_steps=700 \ + --seed="0" +``` + +We can direcly pass a quantized checkpoint path, too: + +```diff ++ --quantized_model_path="hf-internal-testing/flux.1-dev-nf4-pkg" +``` + +Depending on the machine, training time will vary but for our case, it was 1.5 hours. It maybe possible to speed this up by using `torch.bfloat16`. + +We support training with the DeepSpeed Zero2 optimizer, too. To use it, first install DeepSpeed: + +```bash +pip install -Uq deepspeed +``` + +And then launch: + +```bash +accelerate launch --config_file=ds2.yaml \ + train_dreambooth_lora_flux_miniature.py \ + --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \ + --data_df_path="embeddings.parquet" \ + --output_dir="yarn_art_lora_flux_nf4" \ + --mixed_precision="no" \ + --use_8bit_adam \ + --weighting_scheme="none" \ + --resolution=1024 \ + --train_batch_size=1 \ + --repeats=1 \ + --learning_rate=1e-4 \ + --guidance_scale=1 \ + --report_to="wandb" \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --cache_latents \ + --rank=4 \ + --max_train_steps=700 \ + --seed="0" +``` + +## Inference + +When loading the LoRA params (that were obtained on a quantized base model) and merging them into the base model, it is recommended to first dequantize the base model, merge the LoRA params into it, and then quantize the model again. This is because merging into 4bit quantized models can lead to some rounding errors. Below, we provide an end-to-end example: + +1. First, load the original model and merge the LoRA params into it: + +```py +from diffusers import FluxPipeline +import torch + +ckpt_id = "black-forest-labs/FLUX.1-dev" +pipeline = FluxPipeline.from_pretrained( + ckpt_id, text_encoder=None, text_encoder_2=None, torch_dtype=torch.float16 +) +pipeline.load_lora_weights("yarn_art_lora_flux_nf4", weight_name="pytorch_lora_weights.safetensors") +pipeline.fuse_lora() +pipeline.unload_lora_weights() + +pipeline.transformer.save_pretrained("fused_transformer") +``` + +2. Quantize the model and run inference + +```py +from diffusers import AutoPipelineForText2Image, FluxTransformer2DModel, BitsAndBytesConfig +import torch + +ckpt_id = "black-forest-labs/FLUX.1-dev" +bnb_4bit_compute_dtype = torch.float16 +nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, +) +transformer = FluxTransformer2DModel.from_pretrained( + "fused_transformer", + quantization_config=nf4_config, + torch_dtype=bnb_4bit_compute_dtype, +) +pipeline = AutoPipelineForText2Image.from_pretrained( + ckpt_id, transformer=transformer, torch_dtype=bnb_4bit_compute_dtype +) +pipeline.enable_model_cpu_offload() + +image = pipeline( + "a puppy in a pond, yarn art style", num_inference_steps=28, guidance_scale=3.5, height=768 +).images[0] +image.save("yarn_merged.png") +``` + +| Dequantize, merge, quantize | Merging directly into quantized model | +|-------|-------| +| ![Image A](https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/quantized_flux_training/merged.png) | ![Image B](https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/quantized_flux_training/unmerged.png) | + +As we can notice the first column result follows the style more closely. \ No newline at end of file diff --git a/examples/research_projects/flux_lora_quantization/accelerate.yaml b/examples/research_projects/flux_lora_quantization/accelerate.yaml new file mode 100644 index 000000000000..309e13cc140a --- /dev/null +++ b/examples/research_projects/flux_lora_quantization/accelerate.yaml @@ -0,0 +1,17 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: NO +downcast_bf16: 'no' +enable_cpu_affinity: true +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/research_projects/flux_lora_quantization/compute_embeddings.py b/examples/research_projects/flux_lora_quantization/compute_embeddings.py new file mode 100644 index 000000000000..8e93af961e65 --- /dev/null +++ b/examples/research_projects/flux_lora_quantization/compute_embeddings.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import pandas as pd +import torch +from datasets import load_dataset +from huggingface_hub.utils import insecure_hashlib +from tqdm.auto import tqdm +from transformers import T5EncoderModel + +from diffusers import FluxPipeline + + +MAX_SEQ_LENGTH = 77 +OUTPUT_PATH = "embeddings.parquet" + + +def generate_image_hash(image): + return insecure_hashlib.sha256(image.tobytes()).hexdigest() + + +def load_flux_dev_pipeline(): + id = "black-forest-labs/FLUX.1-dev" + text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_2", load_in_8bit=True, device_map="auto") + pipeline = FluxPipeline.from_pretrained( + id, text_encoder_2=text_encoder, transformer=None, vae=None, device_map="balanced" + ) + return pipeline + + +@torch.no_grad() +def compute_embeddings(pipeline, prompts, max_sequence_length): + all_prompt_embeds = [] + all_pooled_prompt_embeds = [] + all_text_ids = [] + for prompt in tqdm(prompts, desc="Encoding prompts."): + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=max_sequence_length) + all_prompt_embeds.append(prompt_embeds) + all_pooled_prompt_embeds.append(pooled_prompt_embeds) + all_text_ids.append(text_ids) + + max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 + print(f"Max memory allocated: {max_memory:.3f} GB") + return all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids + + +def run(args): + dataset = load_dataset("Norod78/Yarn-art-style", split="train") + image_prompts = {generate_image_hash(sample["image"]): sample["text"] for sample in dataset} + all_prompts = list(image_prompts.values()) + print(f"{len(all_prompts)=}") + + pipeline = load_flux_dev_pipeline() + all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids = compute_embeddings( + pipeline, all_prompts, args.max_sequence_length + ) + + data = [] + for i, (image_hash, _) in enumerate(image_prompts.items()): + data.append((image_hash, all_prompt_embeds[i], all_pooled_prompt_embeds[i], all_text_ids[i])) + print(f"{len(data)=}") + + # Create a DataFrame + embedding_cols = ["prompt_embeds", "pooled_prompt_embeds", "text_ids"] + df = pd.DataFrame(data, columns=["image_hash"] + embedding_cols) + print(f"{len(df)=}") + + # Convert embedding lists to arrays (for proper storage in parquet) + for col in embedding_cols: + df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist()) + + # Save the dataframe to a parquet file + df.to_parquet(args.output_path) + print(f"Data successfully serialized to {args.output_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--max_sequence_length", + type=int, + default=MAX_SEQ_LENGTH, + help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.", + ) + parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.") + args = parser.parse_args() + + run(args) diff --git a/examples/research_projects/flux_lora_quantization/ds2.yaml b/examples/research_projects/flux_lora_quantization/ds2.yaml new file mode 100644 index 000000000000..beed28fd90ab --- /dev/null +++ b/examples/research_projects/flux_lora_quantization/ds2.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: 'no' +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py new file mode 100644 index 000000000000..fd2b5568d6d8 --- /dev/null +++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py @@ -0,0 +1,1183 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm + +import diffusers +from diffusers import ( + AutoencoderKL, + BitsAndBytesConfig, + FlowMatchEulerDiscreteScheduler, + FluxPipeline, + FluxTransformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + free_memory, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + pass + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.31.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + base_model: str = None, + instance_prompt=None, + repo_folder=None, + quantization_config=None, +): + widget_dict = [] + + model_description = f""" +# Flux DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md). + +Was LoRA for the text encoder enabled? False. + +Quantization config: + +```yaml +{quantization_config} +``` + +## Trigger words + +You should use `{instance_prompt}` to trigger the image generation. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## Usage + +TODO + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "flux", + "flux-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--quantized_model_path", + type=str, + default=None, + help="Path to the quantized model.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--data_df_path", + type=str, + default=None, + help=("Path to the parquet file serialized with compute_embeddings.py."), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--max_sequence_length", + type=int, + default=77, + help="Used for reading the embeddings. Needs to be the same as used during `compute_embeddings.py`.", + ) + + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + + parser.add_argument( + "--output_dir", + type=str, + default="flux-dreambooth-lora-nf4", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args + + +class DreamBoothDataset(Dataset): + def __init__( + self, + data_df_path, + dataset_name, + size=1024, + max_sequence_length=77, + center_crop=False, + ): + # Logistics + self.size = size + self.center_crop = center_crop + self.max_sequence_length = max_sequence_length + + self.data_df_path = Path(data_df_path) + if not self.data_df_path.exists(): + raise ValueError("`data_df_path` doesn't exists.") + + # Load images. + dataset = load_dataset(dataset_name, split="train") + instance_images = [sample["image"] for sample in dataset] + image_hashes = [self.generate_image_hash(image) for image in instance_images] + self.instance_images = instance_images + self.image_hashes = image_hashes + + # Image transformations + self.pixel_values = self.apply_image_transformations( + instance_images=instance_images, size=size, center_crop=center_crop + ) + + # Map hashes to embeddings. + self.data_dict = self.map_image_hash_embedding(data_df_path=data_df_path) + + self.num_instance_images = len(instance_images) + self._length = self.num_instance_images + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = self.pixel_values[index % self.num_instance_images] + image_hash = self.image_hashes[index % self.num_instance_images] + prompt_embeds, pooled_prompt_embeds, text_ids = self.data_dict[image_hash] + example["instance_images"] = instance_image + example["prompt_embeds"] = prompt_embeds + example["pooled_prompt_embeds"] = pooled_prompt_embeds + example["text_ids"] = text_ids + return example + + def apply_image_transformations(self, instance_images, size, center_crop): + pixel_values = [] + + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + for image in instance_images: + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + image = train_resize(image) + if args.random_flip and random.random() < 0.5: + # flip + image = train_flip(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + image = train_transforms(image) + pixel_values.append(image) + + return pixel_values + + def convert_to_torch_tensor(self, embeddings: list): + prompt_embeds = embeddings[0] + pooled_prompt_embeds = embeddings[1] + text_ids = embeddings[2] + prompt_embeds = np.array(prompt_embeds).reshape(self.max_sequence_length, 4096) + pooled_prompt_embeds = np.array(pooled_prompt_embeds).reshape(768) + text_ids = np.array(text_ids).reshape(77, 3) + return torch.from_numpy(prompt_embeds), torch.from_numpy(pooled_prompt_embeds), torch.from_numpy(text_ids) + + def map_image_hash_embedding(self, data_df_path): + hashes_df = pd.read_parquet(data_df_path) + data_dict = {} + for i, row in hashes_df.iterrows(): + embeddings = [row["prompt_embeds"], row["pooled_prompt_embeds"], row["text_ids"]] + prompt_embeds, pooled_prompt_embeds, text_ids = self.convert_to_torch_tensor(embeddings=embeddings) + data_dict.update({row["image_hash"]: (prompt_embeds, pooled_prompt_embeds, text_ids)}) + return data_dict + + def generate_image_hash(self, image): + return insecure_hashlib.sha256(image.tobytes()).hexdigest() + + +def collate_fn(examples): + pixel_values = [example["instance_images"] for example in examples] + prompt_embeds = [example["prompt_embeds"] for example in examples] + pooled_prompt_embeds = [example["pooled_prompt_embeds"] for example in examples] + text_ids = [example["text_ids"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + prompt_embeds = torch.stack(prompt_embeds) + pooled_prompt_embeds = torch.stack(pooled_prompt_embeds) + text_ids = torch.stack(text_ids)[0] # just 2D tensor + + batch = { + "pixel_values": pixel_values, + "prompt_embeds": prompt_embeds, + "pooled_prompt_embeds": pooled_prompt_embeds, + "text_ids": text_ids, + } + return batch + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Load scheduler and models + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + bnb_4bit_compute_dtype = torch.float32 + if args.mixed_precision == "fp16": + bnb_4bit_compute_dtype = torch.float16 + elif args.mixed_precision == "bf16": + bnb_4bit_compute_dtype = torch.bfloat16 + if args.quantized_model_path is not None: + transformer = FluxTransformer2DModel.from_pretrained( + args.quantized_model_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + torch_dtype=bnb_4bit_compute_dtype, + ) + else: + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + ) + transformer = FluxTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + quantization_config=nf4_config, + torch_dtype=bnb_4bit_compute_dtype, + ) + transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False) + + # We only train the additional adapter LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + vae.to(accelerator.device, dtype=weight_dtype) + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + transformer.add_adapter(transformer_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + + for model in models: + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + model = unwrap_model(model) + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + FluxPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + text_encoder_lora_layers=None, + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + if not accelerator.distributed_type == DistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(transformer))): + transformer_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + if args.quantized_model_path is not None: + transformer_ = FluxTransformer2DModel.from_pretrained( + args.quantized_model_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + torch_dtype=bnb_4bit_compute_dtype, + ) + else: + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + ) + transformer_ = FluxTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + quantization_config=nf4_config, + torch_dtype=bnb_4bit_compute_dtype, + ) + transformer_ = prepare_model_for_kbit_training(transformer_, use_gradient_checkpointing=False) + transformer_.add_adapter(transformer_lora_config) + + lora_state_dict = FluxPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [transformer] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + params_to_optimize = [transformer_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + data_df_path=args.data_df_path, + dataset_name="Norod78/Yarn-art-style", + size=args.resolution, + max_sequence_length=args.max_sequence_length, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + ) + + vae_config_shift_factor = vae.config.shift_factor + vae_config_scaling_factor = vae.config.scaling_factor + vae_config_block_out_channels = vae.config.block_out_channels + if args.cache_latents: + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=weight_dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + + del vae + free_memory() + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-flux-dev-lora-nf4" + accelerator.init_trackers(tracker_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + with accelerator.accumulate(models_to_accumulate): + # Convert images to latent space + if args.cache_latents: + model_input = latents_cache[step].sample() + else: + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor + model_input = model_input.to(dtype=weight_dtype) + + vae_scale_factor = 2 ** (len(vae_config_block_out_channels)) + + latent_image_ids = FluxPipeline._prepare_latent_image_ids( + model_input.shape[0], + model_input.shape[2], + model_input.shape[3], + accelerator.device, + weight_dtype, + ) + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + packed_noisy_model_input = FluxPipeline._pack_latents( + noisy_model_input, + batch_size=model_input.shape[0], + num_channels_latents=model_input.shape[1], + height=model_input.shape[2], + width=model_input.shape[3], + ) + + # handle guidance + if transformer.config.guidance_embeds: + guidance = torch.tensor([args.guidance_scale], device=accelerator.device) + guidance = guidance.expand(model_input.shape[0]) + else: + guidance = None + + # Predict the noise + prompt_embeds = batch["prompt_embeds"].to(device=accelerator.device, dtype=weight_dtype) + pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(device=accelerator.device, dtype=weight_dtype) + text_ids = batch["text_ids"].to(device=accelerator.device, dtype=weight_dtype) + model_pred = transformer( + hidden_states=packed_noisy_model_input, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timesteps / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + model_pred = FluxPipeline._unpack_latents( + model_pred, + height=int(model_input.shape[2] * vae_scale_factor / 2), + width=int(model_input.shape[3] * vae_scale_factor / 2), + vae_scale_factor=vae_scale_factor, + ) + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - model_input + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients: + params_to_clip = transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + transformer = unwrap_model(transformer) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + FluxPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + text_encoder_lora_layers=None, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + base_model=args.pretrained_model_name_or_path, + instance_prompt=None, + repo_folder=args.output_dir, + quantization_config=transformer.config["quantization_config"], + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args)