Crash when loading FLUX.1 Schnell model with train_dreambooth_lora_flux #11045

@rleygonie

Description

Describe the bug

When using the diffusers examples/dreambooth/train_dreambooth_lora_flux.py script with the FLUX.1 Schnell model, the process consistently crashes while loading the transformer checkpoint shards, at 33% (1/3), and takes down my entire Google JupyterLab kernel.

Question: is this related to using the FLUX.1 Schnell model instead of the FLUX.1 Dev model? Is there a known incompatibility?
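
For context, a minimal isolation sketch (my own assumption, not part of the training script): load only the transformer component of the same checkpoint. If this alone kills the kernel, the crash is more likely memory exhaustion while materializing the float32 shards than a Schnell-specific incompatibility.

import torch
from diffusers import FluxTransformer2DModel

# Load just the FLUX.1-schnell transformer (~12B parameters).
# The training script instantiates it under the default torch.float32 (see the
# log below); requesting bf16 here roughly halves the peak memory needed.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
print(f"{sum(p.numel() for p in transformer.parameters()) / 1e9:.1f}B parameters loaded")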

Logs:

03/12/2025 14:14:26 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: bf16

You set add_prefix_space. The tokenizer needs to be converted from the slow tokenizers
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using a model of type t5 to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
{'use_karras_sigmas', 'shift_terminal', 'use_beta_sigmas', 'time_shift_type', 'invert_sigmas', 'use_exponential_sigmas'} was not found in config. Values will be initialized to default values.

Loading checkpoint shards:   0%|                | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards:  50%|████████        | 1/2 [00:13<00:13, 13.01s/it]
Loading checkpoint shards: 100%|████████████████| 2/2 [00:25<00:00, 12.53s/it]
Loading checkpoint shards: 100%|████████████████| 2/2 [00:25<00:00, 12.60s/it]
Instantiating AutoencoderKL model under default dtype torch.float32.
All model checkpoint weights were used when initializing AutoencoderKL.

All the weights of AutoencoderKL were initialized from the model checkpoint at /home/jupyter/flux_model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use AutoencoderKL for predictions without further training.
Instantiating FluxTransformer2DModel model under default dtype torch.float32.
{'out_channels', 'axes_dims_rope'} was not found in config. Values will be initialized to default values.

Loading checkpoint shards:   0%|                | 0/3 [00:00<?, ?it/s]
Loading checkpoint shards:  33%|█████▎          | 1/3 [00:26<00:52, 26.10s/it]

Reproduction

export MODEL_NAME="black-forest-labs/FLUX.1-schnell"
export INSTANCE_DIR="images"
export OUTPUT_DIR="output"

accelerate launch train_dreambooth_lora_flux.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --guidance_scale=1 \
  --gradient_accumulation_steps=4 \
  --optimizer="prodigy" \
  --learning_rate=1. \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0"
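
Because the kernel dies silently at the 33% shard instead of raising a Python exception, a hypothetical way to check for a host-RAM OOM is to watch system memory from the notebook while the command above runs. This diagnostic is my own (it assumes psutil is installed) and is not part of the example scripts.

import threading
import time

import psutil  # extra dependency, used only for this diagnostic


def log_system_memory(interval_s: float = 5.0) -> None:
    # Periodically print system-wide available memory; if it collapses toward
    # zero right before the kernel dies, the OOM killer is the likely culprit.
    while True:
        vm = psutil.virtual_memory()
        print(f"[mem] available={vm.available / 1e9:.1f} GB of {vm.total / 1e9:.1f} GB",
              flush=True)
        time.sleep(interval_s)


threading.Thread(target=log_system_memory, daemon=True).start()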

Logs

System Info

  • 🤗 Diffusers version: 0.33.0.dev0
  • Platform: Linux-5.10.0-33-cloud-amd64-x86_64-with-glibc2.31
  • Running on Google Colab?: No
  • Python version: 3.10.16
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.29.3
  • Transformers version: 4.49.0
  • Accelerate version: 1.4.0
  • PEFT version: 0.14.0
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA L4, 23034 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

No response
