AssertionError: found no DeviceMesh from dtensor args for c10d.broadcast_.default #3820

@juraev

Description

System Info

- `Accelerate` version: 1.11.0
- Platform: Linux-6.8.0-65-generic-x86_64-with-glibc2.39
- `accelerate` bash location: /venv/main/bin/accelerate
- Python version: 3.12.11
- Numpy version: 2.3.4
- PyTorch version: 2.9.0+cu128
- PyTorch accelerator: CUDA
- System RAM: 1007.71 GB
- GPU type: NVIDIA GeForce RTX 4090
- `Accelerate` default config:
        Not found

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

I am running on a single node with 4 RTX 4090 GPUs.

My single node config is:

# FSDP2 Single Node Configuration
# Status: CURRENT - Recommended for new single-node usage

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4  # Adjust for your GPU count
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

and the script is:

from datetime import timedelta
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.utils.data import DataLoader
from datasets import load_dataset


def build_simple_dataloader(tokenizer, seq_len=64, batch_size=2):
    """Build a simple dataloader for reproduction."""
    # Load small dataset
    raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
    raw = raw.filter(lambda x: len(tokenizer(x["text"])['input_ids']) > 0)
    raw = raw.select(range(min(100, len(raw))))  # Use only 100 samples

    def tok_fn(examples):
        return tokenizer(examples["text"], truncation=True, max_length=seq_len)

    ds = raw.map(tok_fn, batched=True, remove_columns=["text"])
    ds.set_format(type="torch", columns=["input_ids"])

    def collate(batch):
        ids = [b["input_ids"] for b in batch]
        labels = [x.clone() for x in ids]
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        x = torch.nn.utils.rnn.pad_sequence(ids, batch_first=True, padding_value=pad_id)
        y = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
        return {"input_ids": x, "labels": y}

    return DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate)


def main():
    # Configuration
    MODEL_NAME = "Qwen/Qwen3-0.6B"
    BATCH_SIZE = 2
    SEQ_LEN = 64
    TP = 2
    DP = 4 // TP

    # Setup Accelerator with FSDP2
    init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
    pc = ParallelismConfig(dp_shard_size=DP, tp_size=TP)

    fsdp_plugin = FullyShardedDataParallelPlugin(
        fsdp_version=2,
        reshard_after_forward=True,
        auto_wrap_policy="transformer_based_wrap",
        state_dict_type="SHARDED_STATE_DICT",
        activation_checkpointing=False,
        cpu_ram_efficient_loading=True,
    )

    accelerator = Accelerator(
        kwargs_handlers=[init_kwargs],
        parallelism_config=pc,
        fsdp_plugin=fsdp_plugin
    )

    rank = accelerator.process_index
    print(f"[Rank {rank}] Initializing...")

    # Load model with TP if needed
    model_kwargs = (
        {"tp_size": TP, "tp_plan": "auto", "device_mesh": accelerator.torch_device_mesh}
        if TP > 1
        else {}
    )

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        use_cache=False,
        **model_kwargs
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    print(f"[Rank {rank}] Building dataloader...")
    loader = build_simple_dataloader(tokenizer, seq_len=SEQ_LEN, batch_size=BATCH_SIZE)

    print(f"[Rank {rank}] Preparing with accelerator...")
    # ERROR OCCURS HERE AT LINE 110 in original script
    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

    print(f"[Rank {rank}] Preparation successful!")

    # Optional: Run a simple forward pass to verify everything works
    print(f"[Rank {rank}] Running test forward pass...")
    model.train()
    batch = next(iter(loader))
    outputs = model(**batch)
    print(f"[Rank {rank}] Forward pass successful! Loss: {outputs.loss.item():.4f}")


if __name__ == "__main__":
    main()

and I launch it with:

accelerate launch --config_file hessian_toolkit/configs/fsdp2_single_node.yaml minimal_reproduce_accelerate_bug.py

This produces:

[Rank 0] Initializing...
[Rank 1] Initializing...
[Rank 2] Initializing...
[Rank 3] Initializing...
[Rank 1] Building dataloader...
[Rank 2] Building dataloader...
[Rank 3] Building dataloader...
[Rank 0] Building dataloader...
[Rank 1] Preparing with accelerator...
[Rank 2] Preparing with accelerator...
[Rank 0] Preparing with accelerator...
[Rank 3] Preparing with accelerator...
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/Scaled-Lanczos/minimal_reproduce_accelerate_bug.py", line 99, in <module>
[rank0]:     main()
[rank0]:   File "/workspace/Scaled-Lanczos/minimal_reproduce_accelerate_bug.py", line 86, in main
[rank0]:     model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
[rank0]:                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/venv/main/lib/python3.12/site-packages/accelerate/accelerator.py", line 1555, in prepare
[rank0]:     result = self._prepare_fsdp2(*args)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/venv/main/lib/python3.12/site-packages/accelerate/accelerator.py", line 1687, in _prepare_fsdp2
[rank0]:     model = fsdp2_prepare_model(self, model)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/venv/main/lib/python3.12/site-packages/accelerate/utils/fsdp_utils.py", line 678, in fsdp2_prepare_model
[rank0]:     fsdp2_load_full_state_dict(accelerator, model, original_sd)
[rank0]:   File "/venv/main/lib/python3.12/site-packages/accelerate/utils/fsdp_utils.py", line 509, in fsdp2_load_full_state_dict
[rank0]:     dist.broadcast(full_param, src=0, group=dist.group.WORLD)
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 2835, in broadcast
[rank0]:     work = group.broadcast([tensor], opts)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/distributed/tensor/_api.py", line 349, in __torch_dispatch__
[rank0]:     return DTensor._op_dispatcher.dispatch(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 152, in dispatch
[rank0]:     op_info = self.unwrap_to_op_info(op_call, args, kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/venv/main/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 432, in unwrap_to_op_info
[rank0]:     assert compute_mesh is not None, (
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError: found no DeviceMesh from dtensor args for c10d.broadcast_.default!

Expected behavior

I expect the above code to run without a runtime error.

While my original code is more involved, I am confident that the culprit is captured by this minimal example.
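For what it's worth, my reading of the traceback (I have not verified this against the Accelerate source) is that fsdp2_load_full_state_dict broadcasts each full_param over dist.group.WORLD, and since the model was loaded with tp_plan="auto" its state dict already contains DTensors, so the broadcast goes through the DTensor op dispatcher, which then cannot resolve a DeviceMesh for c10d.broadcast_. If that reading is correct, one possible (untested) way to sidestep that code path would be to load the full weights on every rank by disabling cpu_ram_efficient_loading in the plugin:

fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    reshard_after_forward=True,
    auto_wrap_policy="transformer_based_wrap",
    state_dict_type="SHARDED_STATE_DICT",
    activation_checkpointing=False,
    # Hypothesis: with this disabled, fsdp2_prepare_model should not need to
    # broadcast a full state dict from rank 0, so dist.broadcast is never
    # called on DTensor parameters.
    cpu_ram_efficient_loading=False,
)

This would use more host RAM per process and does not address the underlying issue; I mention it only as a possible clue about where the bug lives.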
