Fix: try aligning dtype of matrixes when training with deepspeed and mixed-precision is set to bf16 or fp16 #2060

Merged: kohya-ss merged 10 commits into kohya-ss:sd3 from saibit-tech:sd3 on May 1, 2025

@sharlynxy
Contributor

@sharlynxy sharlynxy commented Apr 22, 2025

What problem does this PR solve?

This PR mainly tries to fix the problem described in issue #1871.
When I tried training with the script flux_train.py, I hit the same error as in that issue.
Removing --deepspeed avoids the error, but then I can only run with a low train_batch_size, which makes training slow.

Solution
I added a wrapper in deepspeed_utils.py that wraps each model's forward function with torch.autocast, which provides convenience methods for mixed precision.
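The idea of wrapping a forward function with torch.autocast can be sketched as follows. This is a minimal illustration, not the code from this PR: the helper name wrap_forward_with_autocast is hypothetical, and the example uses the CPU device with bfloat16 only so that it runs without a GPU.

```python
import functools

import torch


def wrap_forward_with_autocast(model, device_type, dtype):
    """Hypothetical helper: make model.forward run inside torch.autocast."""
    original_forward = model.forward

    @functools.wraps(original_forward)
    def forward(*args, **kwargs):
        # Inside autocast, eligible ops such as linear layers cast their
        # inputs to a common lower-precision dtype before computing.
        with torch.autocast(device_type=device_type, dtype=dtype):
            return original_forward(*args, **kwargs)

    # The instance attribute shadows the class-level forward method.
    model.forward = forward
    return model


# A layer with bf16 weights receiving an fp32 input: the kind of mix that
# typically triggers "mat1 and mat2 must have the same dtype".
layer = torch.nn.Linear(4, 2).to(torch.bfloat16)
x = torch.randn(3, 4)  # float32

try:
    layer(x)
except RuntimeError as err:
    print("without autocast:", err)

wrapped = wrap_forward_with_autocast(layer, device_type="cpu", dtype=torch.bfloat16)
out = wrapped(x)
print("with autocast:", out.dtype)
```

Wrapping at the forward boundary leaves the model's parameters and the calling code untouched, which is why it fits a case where another component (here DeepSpeed) controls parameter dtypes.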

Changes detailed

  1. Add __warp_with_torch_autocast to class DeepSpeedWrapper in deepspeed_utils.py.
  2. Add deepspeed==0.16.7 to the requirements.
  3. Do nothing in the function patch_accelerator_for_fp16_training in train_util.py when accelerator.distributed_type == DistributedType.DEEPSPEED. DeepSpeed handles loss scaling for mixed-precision training internally, so accelerator.scaler is None, which results in the same error as issue 476.
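The early return described in item 3 can be sketched with stand-in objects. Everything here is an assumption for illustration: DistributedType is a two-member stand-in for the enum from accelerate, FakeScaler is a hypothetical gradient scaler, and the patch body (forcing allow_fp16 to True) is a placeholder that may differ from the real function in train_util.py.

```python
from enum import Enum
from types import SimpleNamespace


class DistributedType(Enum):
    """Stand-in for accelerate's DistributedType (only two members modeled)."""
    NO = "NO"
    DEEPSPEED = "DEEPSPEED"


class FakeScaler:
    """Hypothetical scaler exposing a private unscale hook."""

    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
        return allow_fp16


def patch_accelerator_for_fp16_training(accelerator):
    # Under DeepSpeed the scaler is None, so touching it would raise
    # AttributeError. Return early and leave loss scaling to DeepSpeed.
    if accelerator.distributed_type == DistributedType.DEEPSPEED:
        return False

    original = accelerator.scaler._unscale_grads_

    def replacer(optimizer, inv_scale, found_inf, allow_fp16):
        # Placeholder patch: always allow fp16 gradients to be unscaled.
        return original(optimizer, inv_scale, found_inf, True)

    accelerator.scaler._unscale_grads_ = replacer
    return True


deepspeed_acc = SimpleNamespace(distributed_type=DistributedType.DEEPSPEED, scaler=None)
plain_acc = SimpleNamespace(distributed_type=DistributedType.NO, scaler=FakeScaler())

skipped = patch_accelerator_for_fp16_training(deepspeed_acc)  # no error, no patch
patched = patch_accelerator_for_fp16_training(plain_acc)      # scaler is patched
```

The boolean return value exists only to make the two paths visible in this sketch.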

After these changes, the dtype error disappeared and train_batch_size increased from 2 (without deepspeed) to 12 (with deepspeed and mixed precision) on 8x Nvidia A100 GPUs (80GB memory each), giving a 17.54% speedup with the following command:

accelerate launch \
  --num_cpu_threads_per_process=8 \
  --multi_gpu \
  --mixed_precision=fp16 \
  --rdzv_backend=c10d \
  "flux_train.py" \
  --output_dir="output" \
  --logging_dir="logs" \
  --max_train_epochs=60 \
  --learning_rate=2e-5 \
  --output_name=flux_test \
  --save_every_n_epochs=10 \
  --save_precision=fp16 \
  --seed=4242 \
  --max_token_length=225 \
  --caption_extension=.txt \
  --vae_batch_size=4 \
  --deepspeed \
  --zero_stage=3 \
  --ddp_timeout=120 \
  --ddp_gradient_as_bucket_view \
  --ddp_static_graph \
  --mem_eff_save \
  --clip_l="model/clip/ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors" \
  --t5xxl="model/clip/t5xxl_fp16.safetensors" \
  --apply_t5_attn_mask \
  --discrete_flow_shift=3.185 \
  --timestep_sampling=flux_shift \
  --sigmoid_scale=1 \
  --model_prediction_type=raw \
  --guidance_scale=1 \
  --ae="model/flux/ae.safetensors" \
  --cache_text_encoder_outputs \
  --cache_text_encoder_outputs_to_disk \
  --sdpa \
  --train_data_dir="data" \
  --train_batch_size=12 \
  --resolution=1024,1024 \
  --enable_bucket \
  --min_bucket_reso=256 \
  --max_bucket_reso=2048 \
  --bucket_no_upscale \
  --pretrained_model_name_or_path="model/flux/flux1-dev.safetensors" \
  --save_model_as=safetensors \
  --clip_skip=2 \
  --persistent_data_loader_workers \
  --cache_latents \
  --cache_latents_to_disk \
  --gradient_checkpointing \
  --use_8bit_adam \
  --keep_tokens=1 \
  --keep_tokens_separator="|||" \
  --secondary_separator=";;;" \
  --sample_every_n_epochs=200 \
  --sample_sampler=euler_a \
  --full_fp16 \
  --mixed_precision=fp16 \
  --gradient_accumulation_steps=1 \
  --lr_scheduler=warmup_stable_decay \
  --lr_scheduler_num_cycles=1 \
  --lr_decay_steps=0.25 \
  --lr_scheduler_min_lr_ratio=0.1

sharlynxy and others added 6 commits April 22, 2025 16:06
…ry aligning precision when using mixed precision in training process
…:saibit-tech/sd-scripts into dev/xy/align_dtype_using_mixed_precision
…_precision

Fix: try aligning dtype of matrixes when training with deepspeed and mixed-precision is set to bf16 or fp16
Owner

@kohya-ss kohya-ss left a comment


I don't personally use DeepSpeed, but this PR looks good to me. I would appreciate it if you could check the comments.

@rockerBOO
Contributor

My questions are about the requirements. Where was deepspeed coming from before? Is updating to 2.6.0 and having diffusers automatically update it a good idea, given the various ways you need to install torch for backend compatibility? Is the diffusers[torch] extra holding this back? It seems like it would be aligned in the same way, with the torch version being the same. Some environments still only support torch 2.4, so requiring the latest (2.6.0) might cause some issues.

@sharlynxy
Contributor Author

> My questions are about the requirements. Where was deepspeed coming from before? Is updating to 2.6.0 and having diffusers automatically update it a good idea, given the various ways you need to install torch for backend compatibility? Is the diffusers[torch] extra holding this back? It seems like it would be aligned in the same way, with the torch version being the same. Some environments still only support torch 2.4, so requiring the latest (2.6.0) might cause some issues.

  1. I followed the installation steps in the README of the sd3 branch, but the following commands do not install deepspeed:
pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124
pip3 install -r requirements.txt # in branch sd3

I tried installing deepspeed by putting accelerate[deepspeed]==0.33.0 in requirements.txt, but deepspeed was still not detected when launching training:

2025-04-23 16:40:55 ERROR    deepspeed is not installed. please install deepspeed in your deepspeed_utils.py:77
                             environment with following command. DS_BUILD_OPS=0 pip                            
                             install deepspeed  

Then I added deepspeed==0.16.7 to requirements.txt, which works well for this PR and does not conflict with the existing requirements.
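One way to confirm whether deepspeed is importable in the active environment before launching training is a small standard-library check. This is a sketch with a hypothetical helper name, is_installed, and it is not necessarily how deepspeed_utils.py performs its own detection.

```python
import importlib.util


def is_installed(package: str) -> bool:
    """Return True if a top-level package can be found on the import path."""
    return importlib.util.find_spec(package) is not None


has_deepspeed = is_installed("deepspeed")
print(f"deepspeed importable: {has_deepspeed}")
```

Running this with the same Python interpreter that accelerate launch uses helps rule out the case where the package was installed into a different environment.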

  2. It seems my previous comment has caused some misunderstanding. My concern was that using diffusers[torch]==0.25.0 might automatically upgrade an existing torch 2.4.0 installation to the latest version. I ran into this once when I tried this repo in another environment, which is why I changed it to diffusers==0.25.0. However, since I can’t reproduce the issue now, I’ll change it back to diffusers[torch]==0.25.0.

* get device type from model

* add logger warning

* format

* format

* format
@sharlynxy sharlynxy requested a review from kohya-ss April 24, 2025 02:41
Owner

@kohya-ss kohya-ss left a comment


Sorry for the delay. I'd like to confirm requirements.txt.

@kohya-ss kohya-ss merged commit 7c075a9 into kohya-ss:sd3 May 1, 2025
1 of 2 checks passed
@kohya-ss
Owner

kohya-ss commented May 1, 2025

Thank you for the update!

@yuweifanf

yuweifanf commented Jul 11, 2025

Hi, but when I switch to mixed_precision=bf16, it still raises the [mat1 and mat2 must have the same dtype, but got Float and BFloat16] error. I am running the script "flux_train_control_net.py" and the command is

accelerate launch \
  --mixed_precision bf16 \
  --num_cpu_threads_per_process 1 \
  flux_train_control_net.py \
  --pretrained_model_name_or_path flux1-dev.safetensors \
  --clip_l clip_l.safetensors \
  --t5xxl t5xxl_fp16.safetensors \
  --ae ae.safetensors \
  --save_model_as safetensors \
  --sdpa \
  --persistent_data_loader_workers \
  --max_data_loader_n_workers 1 \
  --seed 42 \
  --gradient_checkpointing \
  --mixed_precision bf16 \
  --optimizer_type adamw8bit \
  --learning_rate 2e-5 \
  --highvram \
  --max_train_epochs 1 \
  --save_every_n_steps 1000 \
  --output_dir /path/to/output \
  --output_name flux-cn \
  --timestep_sampling shift \
  --discrete_flow_shift 3.1582 \
  --model_prediction_type raw \
  --guidance_scale 1.0 \
  --deepspeed \
  --dataset_config dataset.toml \
  --log_tracker_name "sd-scripts-flux-cn"

In addition, the machine only supports CUDA 12.2, so I installed PyTorch 2.4.0 from the cu121 channel. Could that cause the problem?
pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu121
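When a question like this comes up, it can help to report what the installed PyTorch build says about itself. The sketch below only gathers that information; the helper name describe_torch_build is hypothetical, and the output does not by itself show whether the cu121 build is related to the dtype error.

```python
import torch


def describe_torch_build() -> dict:
    """Collect basic facts about the installed PyTorch build."""
    cuda_ok = torch.cuda.is_available()
    return {
        "torch": torch.__version__,
        # CUDA version the wheel was built against; None on CPU-only builds.
        "built_for_cuda": torch.version.cuda,
        "cuda_available": cuda_ok,
        # Only query bf16 support when a CUDA device is actually usable.
        "bf16_supported": torch.cuda.is_bf16_supported() if cuda_ok else False,
    }


info = describe_torch_build()
print(info)
```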

4 participants