Closed
Labels
bug (Something isn't working)
Description
Describe the bug
When I perform post-training on Qwen-Image using diffusers/examples/dreambooth/train_dreambooth_lora_qwen_image.py, I encounter the error: 'torch.device' object is not iterable. The difference from the official example script is that I added the --bnb_quantization_config_path parameter.
```json
{
  "load_in_4bit": true,
  "bnb_4bit_quant_type": "nf4",
  "bnb_4bit_compute_dtype": "bfloat16"
}
```
Reproduction
Training script:

```shell
accelerate launch train_dreambooth_lora_qwen_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --bnb_quantization_config_path=$BNB_QUANTIZATION_PATH \
  --mixed_precision="bf16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=2 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --learning_rate=2e-4 \
  --lr_scheduler="cosine" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --rank=4 \
  --lora_alpha=8 \
  --cache_latents \
  --offload \
  --seed="0"
```
Quantization config:

```json
{
  "load_in_4bit": true,
  "bnb_4bit_quant_type": "nf4",
  "bnb_4bit_compute_dtype": "bfloat16"
}
```
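For reference, the training script presumably reads this JSON and maps its fields onto a bitsandbytes quantization config. A minimal sketch of that parsing step, assuming a hypothetical `load_bnb_config` helper (not the actual diffusers code); note that the compute dtype arrives as a string and must be resolved to a real `torch` dtype before use:

```python
import json
import tempfile

def load_bnb_config(path):
    """Hypothetical helper: parse the bitsandbytes JSON config shown above.

    The real script presumably builds a BitsAndBytesConfig from these
    fields; here we only parse and sanity-check them.
    """
    with open(path) as f:
        cfg = json.load(f)
    # "bfloat16" is stored as a string; the script must resolve it to a
    # torch dtype (e.g. getattr(torch, cfg["bnb_4bit_compute_dtype"])).
    if not isinstance(cfg.get("load_in_4bit"), bool):
        raise ValueError("load_in_4bit must be a JSON boolean")
    return cfg

# Round-trip the exact config from this report through a temp file.
example = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": "bfloat16",
}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(example, f)
    path = f.name

cfg = load_bnb_config(path)
print(cfg["bnb_4bit_quant_type"])
```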
Logs
08/15/2025 19:22:57 - INFO - __main__ - Distributed environment: DistributedType.NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
Mixed precision type: bf16
All model checkpoint weights were used when initializing AutoencoderKLQwenImage.
All the weights of AutoencoderKLQwenImage were initialized from the model checkpoint at /root/.cache/modelscope/hub/models/Qwen/Qwen-Image.
If your task is similar to the task the model of the checkpoint was trained on, you can already use AutoencoderKLQwenImage for predictions without further training.
Loading checkpoint shards: 100%|████████████████████████████████████████| 4/4 [00:00<00:00, 53.59it/s]
The device_map was not initialized. Setting device_map to {: {current_device}}. If you want to use the model for inference, please set device_map ='auto'
Instantiating QwenImageTransformer2DModel model under default dtype torch.bfloat16.
The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|████████████████████████████████████████| 9/9 [00:05<00:00, 1.62it/s]
All model checkpoint weights were used when initializing QwenImageTransformer2DModel.
All the weights of QwenImageTransformer2DModel were initialized from the model checkpoint at /root/.cache/modelscope/hub/models/Qwen/Qwen-Image.
If your task is similar to the task the model of the checkpoint was trained on, you can already use QwenImageTransformer2DModel for predictions without further training.
Loading pipeline components...: 100%|████████████████████████████████████████| 2/2 [00:00<00:00, 12945.38it/s]
Traceback (most recent call last):
File "/root/diffusers/examples/dreambooth/train_dreambooth_lora_qwen_image.py", line 1687, in <module>
main(args)
File "/root/diffusers/examples/dreambooth/train_dreambooth_lora_qwen_image.py", line 1279, in main
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/diffusers/lib/python3.12/contextlib.py", line 144, in __exit__
next(self.gen)
File "/root/diffusers/src/diffusers/training_utils.py", line 352, in offload_models
for m, orig_dev in zip(modules, original_devices):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'torch.device' object is not iterable
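Judging from the traceback, `offload_models` appears to receive a single device object where the loop on its exit path expects an iterable of devices, one per module. A pure-Python stand-in that reproduces the same failure shape (`FakeDevice` and `FakeModule` are illustrative stand-ins, not diffusers or torch APIs):

```python
class FakeDevice:
    """Illustrative stand-in for torch.device: a plain, non-iterable object."""

class FakeModule:
    """Illustrative stand-in for a model whose device gets restored."""
    device = None

def restore_devices(modules, original_devices):
    # Mirrors the loop from the traceback: zip() requires original_devices
    # to be an iterable of devices, one per module.
    for m, orig_dev in zip(modules, original_devices):
        m.device = orig_dev

# Passing a bare device instead of a list raises a TypeError, matching the
# shape of the failure reported above:
try:
    restore_devices([FakeModule()], FakeDevice())
except TypeError as err:
    print("raised:", err)

# Wrapping the device in a list is the shape the loop expects:
module, device = FakeModule(), FakeDevice()
restore_devices([module], [device])
assert module.device is device
```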
Traceback (most recent call last):
File "/root/miniconda3/envs/diffusers/bin/accelerate", line 8, in <module>
sys.exit(main())
^^^^^^
File "/root/miniconda3/envs/diffusers/lib/python3.12/site-packages/accelerate/commands/accelerate_cli.py", line 50, in main
args.func(args)
File "/root/miniconda3/envs/diffusers/lib/python3.12/site-packages/accelerate/commands/launch.py", line 1235, in launch_command
simple_launcher(args)
File "/root/miniconda3/envs/diffusers/lib/python3.12/site-packages/accelerate/commands/launch.py", line 823, in simple_launcher
raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
System Info
- 🤗 Diffusers version: 0.35.0.dev0
- Platform: Linux-5.10.134-19.1.al8.x86_64-x86_64-with-glibc2.32
- Running on Google Colab?: No
- Python version: 3.12.11
- PyTorch version (GPU?): 2.8.0+cu128 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.34.4
- Transformers version: 4.55.2
- Accelerate version: 1.10.0
- PEFT version: 0.17.0
- Bitsandbytes version: 0.47.0
- Safetensors version: 0.6.2
- xFormers version: not installed
- Accelerator: NVIDIA L20, 49140 MiB
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
Who can help?
No response