# DreamBooth training example for SANA

[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3~5) images of a subject.

The `train_dreambooth_lora_sana.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [SANA](https://arxiv.org/abs/2410.10629).

This will also allow us to push the trained model parameters to the Hugging Face Hub platform.

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd into the `examples/dreambooth` folder and run:
```bash
pip install -r requirements_sana.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or for a default accelerate configuration without answering questions about your environment

```bash
accelerate config default
```

Or if your environment doesn't support an interactive shell (e.g., a notebook)

```python
from accelerate.utils import write_basic_config
write_basic_config()
```

When running `accelerate config`, setting torch compile mode to True can yield dramatic speedups.
Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.14.0` installed in your environment.
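
To pull in a compatible PEFT version, you can run, for example:

```bash
pip install -U "peft>=0.14.0"
```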

### Dog toy example

Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.

Let's first download it locally:

```python
from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir, repo_type="dataset",
    ignore_patterns=".gitattributes",
)
```

Since we pass `--push_to_hub` in the command below, the trained LoRA parameters will also be uploaded to the Hugging Face Hub.

Now, we can launch training using:

```bash
export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_diffusers"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-sana-lora"

accelerate launch train_dreambooth_lora_sana.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --use_8bit_adam \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub
```

To use `push_to_hub`, make sure you're logged in to your Hugging Face account:

```bash
huggingface-cli login
```
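
If you're working from a notebook or another non-interactive environment, you can log in programmatically instead (this will prompt for, or accept, a Hugging Face access token):

```python
from huggingface_hub import login

login()  # prompts for your Hugging Face token; or pass token="hf_..."
```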

To better track our training experiments, we're using the following flags in the command above:

* `report_to="wandb"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb` and call `wandb login <your_api_key>` before training if you haven't done so already (see the commands after this list).
* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs, so we can qualitatively check whether training is progressing as expected.
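
In practice, that Weights and Biases setup amounts to:

```bash
pip install wandb
wandb login
```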

Additionally, we welcome you to explore the following CLI arguments:

* `--lora_layers`: The transformer modules to apply LoRA training to, given as a comma-separated string. For example, `"to_k,to_q,to_v"` will apply LoRA to the attention layers only (see the example after this list).
* `--complex_human_instruction`: Instructions for complex human instruction, as shown [here](https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55).
* `--max_sequence_length`: Maximum sequence length to use for text embeddings.
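
For example, to train LoRA only on the attention projections and cap the prompt length, a minimal variant of the launch command could look like this (the extra flag values are illustrative, not tuned recommendations):

```bash
accelerate launch train_dreambooth_lora_sana.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --lora_layers="to_k,to_q,to_v" \
  --max_sequence_length=300
```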

We provide several options for memory optimization (an example command follows this list):

* `--offload`: When enabled, we will offload the text encoder and VAE to CPU when they are not used.
* `--cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
* `--use_8bit_adam`: When enabled, we will use the 8-bit version of AdamW provided by the `bitsandbytes` library.
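
For instance, a lower-memory sketch of the training command above could add these flags (combine them with the other flags as needed):

```bash
accelerate launch train_dreambooth_lora_sana.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --offload \
  --cache_latents \
  --use_8bit_adam
```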
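
Once training finishes, the LoRA weights can be loaded on top of the base SANA pipeline for inference. A minimal sketch, assuming the LoRA was saved to (or pushed as) `trained-sana-lora`; swap in your own Hub repo id or local path and adjust dtypes/devices to your hardware:

```python
import torch
from diffusers import SanaPipeline

# Load the base SANA pipeline used during training.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    torch_dtype=torch.bfloat16,
)
# Load the DreamBooth LoRA weights produced by the training script.
pipe.load_lora_weights("trained-sana-lora")
pipe.to("cuda")

image = pipe(prompt="A photo of sks dog in a bucket").images[0]
image.save("sks_dog.png")
```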