
fix crash in optimizer.step when fsdp2 is enabled and model is bfloat16 #3905

Merged
SunMarc merged 1 commit into huggingface:main from sywangyi:np_parall_trainer_fix
Jan 13, 2026

Conversation

@sywangyi
Contributor

@SunMarc
accelerate launch --config-file configs/tp_hsdp.yaml nd_parallel_trainer.py --sequence-length 1024
After applying huggingface/transformers#43226, another crash happens. See the backtrace:

[rank0]: Traceback (most recent call last):
[rank0]: File "/mnt/disk3/wangyi/accelerate/examples/torch_native_parallelism/nd_parallel_trainer.py", line 82, in
[rank0]: main()
[rank0]: File "/mnt/disk3/wangyi/accelerate/examples/torch_native_parallelism/nd_parallel_trainer.py", line 77, in main
[rank0]: trainer.train()
[rank0]: File "/mnt/disk3/wangyi/transformers/src/transformers/trainer.py", line 2174, in train
[rank0]: return inner_training_loop(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/disk3/wangyi/transformers/src/transformers/trainer.py", line 2596, in _inner_training_loop
[rank0]: self.optimizer.step()
[rank0]: File "/mnt/disk3/wangyi/accelerate/src/accelerate/optimizer.py", line 179, in step
[rank0]: self.optimizer.step(closure)
[rank0]: File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/optim/lr_scheduler.py", line 133, in wrapper
[rank0]: return func.__get__(opt, opt.__class__)(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/optim/optimizer.py", line 517, in wrapper
[rank0]: out = func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/optim/optimizer.py", line 82, in _use_grad
[rank0]: ret = func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/optim/adam.py", line 247, in step
[rank0]: adam(
[rank0]: File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/optim/optimizer.py", line 150, in maybe_fallback
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/optim/adam.py", line 956, in adam
[rank0]: func(
[rank0]: File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/optim/adam.py", line 830, in _fused_adam
[rank0]: grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/optim/optimizer.py", line 546, in _group_tensors_by_device_and_dtype
[rank0]: return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices) # type: ignore[return-value, arg-type]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/utils/_foreach_utils.py", line 55, in _group_tensors_by_device_and_dtype
[rank0]: return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Tensors of the same index must be on the same device and the same dtype except step tensors that can be CPU and float32/64 notwithstanding
[rank3]: (identical traceback on rank 3, ending in the same RuntimeError)
0%|
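
For reference, here is a minimal plain-tensor sketch of the grouping check that fires above. This is only an illustration, not the FSDP2/DTensor code path; the shape and the manual .grad assignments are made up. The point is that the fused Adam states get created while the parameter is bf16, and once the parameter later ends up in fp32 the same-index tensors have mixed dtypes and torch raises the same RuntimeError.

import torch

# Illustration only: mimic a dtype drift between a parameter and its already-created
# optimizer state, which is what the fused-Adam grouping check above complains about.
# Requires a CUDA device; does not involve FSDP2 or DTensor.
p = torch.nn.Parameter(torch.randn(16, device="cuda", dtype=torch.bfloat16))
opt = torch.optim.Adam([p], lr=1e-3, fused=True)

p.grad = torch.randn_like(p)       # bf16 grad matching the bf16 param
opt.step()                         # exp_avg / exp_avg_sq are created in bf16 here

p.data = p.data.to(torch.float32)  # works for a plain tensor (unlike a DTensor param)
p.grad = torch.randn_like(p)       # fp32 grad for the now-fp32 param
opt.step()                         # expected to fail with the same RuntimeError:
                                   # "Tensors of the same index must be on the same
                                   #  device and the same dtype ..."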

Member

@SunMarc SunMarc left a comment


SGTM, do you know why only upcasting param.data is not enough?

@sywangyi
Contributor Author

It seems to be a limitation of DTensor: reassigning param.data does not actually change the parameter's dtype when it wraps a DTensor. You can use the following code to test:

import torch
import os
import sys

# Mocking DTensor if possible or minimal check
try:
    from torch.distributed.tensor import DTensor, DeviceMesh, Shard, distribute_tensor
    import torch.distributed as dist

    # We need a distributed environment for DeviceMesh usually
    # Try to setup a dummy one
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"

    if not dist.is_initialized():
        dist.init_process_group("nccl", rank=0, world_size=1)

    device_mesh = DeviceMesh("cuda", [0])

    # Create a tensor
    t = torch.randn(10, 10, dtype=torch.bfloat16)

    # Distribute it
    dt = distribute_tensor(t, device_mesh, [Shard(0)])

    print(f"Original DTensor dtype: {dt.dtype}")
    print(f"Original Local Tensor dtype: {dt.to_local().dtype}")

    # Try the operation in question
    dt_fp32 = dt.to(torch.float32)

    print(f"Converted DTensor dtype: {dt_fp32.dtype}")
    print(f"Converted Local Tensor dtype: {dt_fp32.to_local().dtype}")

    # Mocking Parameter behavior
    p = torch.nn.Parameter(dt)
    print(f"Parameter dtype before: {p.dtype}")

    # Simulate the upcast line
    p.data = p.data.to(torch.float32)

    print(f"Parameter dtype after: {p.dtype}")
    print(f"Parameter data dtype after: {p.data.dtype}")
    print(f"Full Parameter: {p}")

    dist.destroy_process_group()

except ImportError:
    print("Torch distributed or DTensor not available")
except Exception as e:
    import traceback
    traceback.print_exc()
    print(f"Error occurred: {e}")

@SunMarc
Member

SunMarc commented Jan 13, 2026

Thanks for the explanation! Merging.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SunMarc SunMarc merged commit 758f3b1 into huggingface:main Jan 13, 2026
21 of 25 checks passed