Dreambooth SDXL LoRA - mat1 and mat2 shapes cannot be multiplied (2x2048 and 2816x1280) #7239

@shawnrushefsky

Describe the bug

I'm trying to run an SDXL LoRA DreamBooth training job with prior preservation. After the class images are generated, the run dies at step 0 with the logged error: mat1 and mat2 shapes cannot be multiplied (2x2048 and 2816x1280). The instance images are 37 photos of my dog in a variety of sizes.

Reproduction

accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
  --instance_data_dir=/instance_images \
  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
  --output_dir=/output \
  --instance_prompt="timber" \
  --mixed_precision=fp16 \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-05 \
  --lr_scheduler=constant \
  --lr_warmup_steps=0 \
  --checkpointing_steps=100 \
  --seed=0 \
  --resume_from_checkpoint=latest \
  --checkpoints_total_limit=1 \
  --max_train_steps=1800 \
  --train_text_encoder \
  --text_encoder_lr=5e-06 \
  --with_prior_preservation \
  --class_data_dir=/class_images \
  --class_prompt="photo of a dog" \
  --num_class_images=25 \
  --validation_prompt="timber as an ace space pilot, detailed illustration" \
  --validation_epochs=10 \
  --report_to=wandb \
  --sample_batch_size=4

Logs

Steps:   0%|          | 0/1800 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/app/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", line 1793, in <module>
    main(args)
  File "/app/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", line 1572, in main
    model_pred = unet(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 817, in forward
    return model_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 805, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/app/diffusers/src/diffusers/models/unets/unet_2d_condition.py", line 1162, in forward
    aug_emb = self.get_aug_embed(
  File "/app/diffusers/src/diffusers/models/unets/unet_2d_condition.py", line 987, in get_aug_embed
    aug_emb = self.add_embedding(add_embeds)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/diffusers/src/diffusers/models/embeddings.py", line 228, in forward
    sample = self.linear_1(sample)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x2048 and 2816x1280)
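
The shapes in this error can be read with a little arithmetic. This is a sketch, not a confirmed diagnosis. It assumes the SDXL base UNet builds the input to add_embedding by concatenating a 1280-wide pooled text embedding with one 256-wide embedding per "time id" (original size, crop offset, and target size, two values each). Those two widths and the helper names below are assumptions, not taken from the report. Under them, the 2816 that linear_1 expects corresponds to six time ids, while the observed 2048 would correspond to only three. The leading 2 in 2x2048 is consistent with a train batch of 1 doubled by prior preservation (one instance image plus one class image).

```python
# Assumed SDXL base values (not stated in the report):
POOLED_TEXT_DIM = 1280    # assumed width of the pooled embedding from the second text encoder
TIME_ID_EMBED_DIM = 256   # assumed UNet addition_time_embed_dim


def add_embedding_width(num_time_ids: int) -> int:
    """Width of the vector fed to add_embedding.linear_1 under the assumptions above."""
    return POOLED_TEXT_DIM + num_time_ids * TIME_ID_EMBED_DIM


def implied_time_ids(observed_width: int) -> float:
    """Number of time ids that would produce the observed input width."""
    return (observed_width - POOLED_TEXT_DIM) / TIME_ID_EMBED_DIM


expected = add_embedding_width(6)  # 2 values each for original size, crop offset, target size
observed = 2048                    # mat1 width reported in the RuntimeError

print(expected)                    # 2816, matching the mat2 input width in the error
print(implied_time_ids(observed))  # 3.0, i.e. half the expected six time ids
```

If this reading holds, a useful next check is to print the shape of the time-id tensor the training script passes to the UNet and confirm whether it has three entries per sample instead of six.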

System Info

  • diffusers version: 0.27.0.dev0
  • Platform: Linux-5.15.146.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
  • Python version: 3.10.13
  • PyTorch version (GPU?): 2.2.0 (True)
  • Huggingface_hub version: 0.21.3
  • Transformers version: 4.38.2
  • Accelerate version: 0.27.0
  • xFormers version: 0.0.24
  • Using GPU in script?: yes, RTX 4090
  • Using distributed or parallel set-up in script?: no

Who can help?

@sayakpaul

Labels

bug (Something isn't working), training