
Conversation

@sayakpaul
Member

@sayakpaul sayakpaul commented Feb 18, 2025

What does this PR do?

The training script was going to follow in an immediate PR, but I couldn't keep calm and ended up adding it here.

TODOs

  • Training README
  • Training tests (can do in a follow-up PR, too)

Will add a link to a successful wandb run once available, but I think this is already ready for review.

WandB: https://wandb.ai/sayakpaul/dreambooth-lumina2-lora/runs/gpyi2qsl?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul sayakpaul changed the title [LoRA] add LoRA support to Lumina2 [LoRA] add LoRA support to Lumina2 and fine-tuning script Feb 19, 2025
@sayakpaul
Member Author

@zhuole1025 could you help verify the training?

Command:
export MODEL_NAME="Alpha-VLLM/Lumina-Image-2.0"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-lumina2-lora"

CUDA_VISIBLE_DEVICES=0 accelerate launch train_dreambooth_lora_lumina2.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --use_8bit_adam \
  --learning_rate=1e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0"

Download the "dog" dataset like so:

from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir, repo_type="dataset",
    ignore_patterns=".gitattributes",
)

Major question: do we need an SD3-like weighting_scheme ("logit_normal"), or can we simply follow what Flux, SANA, etc. do (the current state)?
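
For reference, a minimal sketch of what the SD3-style sampling would look like, assuming the compute_density_for_timestep_sampling helper from diffusers.training_utils (the scheduler instantiation and batch size below are illustrative):

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_density_for_timestep_sampling

noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
bsz = 4  # illustrative batch size

# The logit-normal density biases sampling toward mid-range timesteps.
u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal",
    batch_size=bsz,
    logit_mean=0.0,
    logit_std=1.0,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices]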

Thanks in advance.

@sayakpaul sayakpaul added lora roadmap Add to current release roadmap labels Feb 19, 2025
@zhuole1025
Contributor

zhuole1025 commented Feb 19, 2025

@zhuole1025 could you help verify the training? [...]

Lumina is completely compatible with Flux, i.e., we use logit-normal for training and dynamic shift for inference. However, Lumina takes time=1.0 as clean data, which is the reverse of Flux and SD3.

@sayakpaul
Member Author

However, Lumina takes time=1.0 as clean data, which is the reverse of Flux and SD3.

Not sure what you mean. We have the following lerp:

sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

You mean it should be:

sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
noisy_model_input = (1.0 - sigmas) * noise + sigmas * model_input

dynamic shift for inference

I guess we don't have to account for that during fine-tuning?
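
For reference, a sketch of what the dynamic shift looks like at inference, assuming the Flux formulation (the defaults below are Flux's, not necessarily Lumina2's):

# mu is linearly interpolated from the number of image tokens and then
# passed to the scheduler, e.g. scheduler.set_timesteps(steps, mu=mu),
# with use_dynamic_shifting=True in the scheduler config.
def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b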

@zhuole1025
Contributor

Yes, I mean when sigmas=1.0, noisy_model_input should be model_input. The fine-tuning schedule should be left to the user: uniform, logit-normal, or with shift. (I think this is already supported in diffusers, so nothing more to do here.)
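
(For the "with shift" option: a sketch of the static sigma shift that FlowMatchEulerDiscreteScheduler applies when its shift config is not 1; the shift value here is illustrative.)

import torch

# Larger shift pushes the sampled sigmas toward the noisy end while
# keeping the endpoints 0 and 1 fixed.
def shift_sigmas(sigmas: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
    return shift * sigmas / (1.0 + (shift - 1.0) * sigmas)

sigmas = torch.linspace(1.0, 1e-3, 10)
print(shift_sigmas(sigmas, shift=3.0))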

@sayakpaul
Member Author

Yes, I mean when sigmas=1.0, noisy_model_input should be model_input. The fine-tuning schedule should be left to the user: uniform, logit-normal, or with shift. (I think this is already supported in diffusers, so nothing more to do here.)

Thanks @zhuole1025. Could you check the training code and confirm if possible? I don't think we train with shift, but I would love to implement that too. Could you provide some guidance?

@zhuole1025
Contributor

Thanks @zhuole1025. Could you check the training code and confirm if possible? I don't think we train with shift, but I would love to implement that too. Could you provide some guidance?

I took a rough look and almost everything seems good, except that our Lumina predicts data - noise, so the training objective should be reversed too.
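
Concretely, where the Flux-style scripts regress onto noise - model_input, a sketch of the reversed objective (names follow the snippets above; the per-sample weighting is omitted for brevity):

# Lumina2 predicts data - noise, so the target flips sign relative to Flux.
target = model_input - noise
loss = torch.mean(
    ((model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
    dim=1,
).mean()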

@sayakpaul
Member Author

@zhuole1025 thanks for your help and guidance so far. https://wandb.ai/sayakpaul/dreambooth-lumina2-lora/runs/d7bxrrnm is the latest run:

accelerate launch train_dreambooth_lora_lumina2.py \
  --pretrained_model_name_or_path=Alpha-VLLM/Lumina-Image-2.0  \
  --dataset_name=Norod78/Yarn-art-style --instance_prompt="a puppy, yarn art style" \
  --output_dir=trained-lumina2-lora-yarn \
  --mixed_precision="bf16" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --use_8bit_adam \
  --learning_rate=1e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=700 \
  --final_validation_prompt="a puppy in a pond, yarn art style" \
  --seed="0"

Do the latest changes look good to you? I think I have resolved all the feedback. And if you do a run on your end on any subject you want, that would be awesome!
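
For anyone trying the trained LoRA, a sketch of loading it for inference (the pipeline class name follows current diffusers, and the sampling parameters are illustrative):

import torch
from diffusers import Lumina2Pipeline

pipe = Lumina2Pipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("trained-lumina2-lora-yarn")  # OUTPUT_DIR from above

image = pipe(
    "a puppy in a pond, yarn art style",
    num_inference_steps=30,
    guidance_scale=4.0,
).images[0]
image.save("yarn_puppy.png")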

@sayakpaul sayakpaul requested a review from zhuole1025 February 19, 2025 10:36
@sayakpaul sayakpaul marked this pull request as ready for review February 19, 2025 11:31
@sayakpaul sayakpaul requested review from a-r-r-o-w and removed request for zhuole1025 February 19, 2025 11:36
Contributor

@a-r-r-o-w a-r-r-o-w left a comment

LGTM!

@sayakpaul sayakpaul merged commit f10d3c6 into main Feb 20, 2025
14 of 15 checks passed
@sayakpaul sayakpaul deleted the lumina2-lora branch February 20, 2025 04:11