
accelerate.load_checkpoint_and_dispatch does not load GPT-OSS Models correctly #3882

@KyleMylonakisProtopia

Description


System Info

- `Accelerate` version: 1.11.0
- Platform: Linux-6.8.0-87-generic-x86_64-with-glibc2.39
- `accelerate` bash location: /home/kyle/llama3/.venv/bin/accelerate
- Python version: 3.10.18
- Numpy version: 2.2.6
- PyTorch version: 2.8.0+cu128
- PyTorch accelerator: CUDA
- System RAM: 2015.50 GB
- GPU type: NVIDIA H100 80GB HBM3
- `Accelerate` default config:
        Not found

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

The two scripts below load the same checkpoint, but the weights they end up with differ.

Run the following script with CUDA_VISIBLE_DEVICES=0 python <script_name>

import itertools

import torch
import transformers


def main() -> None:
    model = transformers.AutoModelForCausalLM.from_pretrained("/models/gpt-oss-20b", dtype=torch.bfloat16)

    uninitialized_tensors: list[str] = [
        name for name, tensor in itertools.chain(model.named_parameters(), model.named_buffers()) if (tensor == 0.0).all()
    ]
    names_str = ", ".join(uninitialized_tensors)
    print(f"The following tensors are all zeros and were likely not properly initialized: {names_str}")


if __name__ == "__main__":
    main()

Run the following script with CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 <script_name>

# © 2022-2025 Protopia AI, Inc. All rights reserved.

from __future__ import annotations

import itertools
import logging
import os
from typing import Final

import accelerate
import torch
import torch.distributed as dist
import transformers

_MODEL_PATH: Final[str] = "/models/gpt-oss-20b"
_TORCH_DTYPE: Final[torch.dtype] = torch.bfloat16
_ZERO_TENSOR_MSG: Final[str] = "The following tensors are all zeros and were likely not properly initialized: %s"


logger = logging.getLogger(__name__)


def _configure_logging() -> None:
    if logging.getLogger().hasHandlers():
        return
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    )


def _init_distributed() -> int:
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    # torch.cuda.set_device(local_rank)
    return local_rank


def _collect_zero_tensors(model: torch.nn.Module) -> list[str]:
    zero_names: list[str] = []
    with torch.no_grad():
        for name, tensor in itertools.chain(model.named_parameters(), model.named_buffers()):
            if tensor.numel() == 0:
                continue
            if torch.count_nonzero(tensor).item() == 0:
                zero_names.append(name)
    return zero_names


def _gather_zero_tensors(local_zero_tensors: list[str]) -> list[str]:
    if not dist.is_initialized():
        return local_zero_tensors
    gathered: list[list[str]] = [None] * dist.get_world_size()  # type: ignore[list-item]
    dist.all_gather_object(gathered, local_zero_tensors)
    combined: list[str] = []
    for zero_list in gathered:
        combined.extend(zero_list)
    return combined


def main() -> None:
    """Check for uninitialized tensors across distributed ranks."""

    _configure_logging()
    local_rank = _init_distributed()
    device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")

    logger.info("Rank %d loading model on %s", dist.get_rank(), device)

    with accelerate.init_empty_weights():
        model = transformers.AutoModelForCausalLM.from_pretrained(_MODEL_PATH, torch_dtype=_TORCH_DTYPE)
    model._apply(
        lambda t: torch.empty_like(t, device="cuda") if t.device == torch.device("meta") else t.to("cuda")
    )

    accelerate.load_checkpoint_and_dispatch(
        model,
        _MODEL_PATH,
        dtype=_TORCH_DTYPE,
        broadcast_from_rank0=True,
    )

    zero_tensors = _collect_zero_tensors(model)
    gathered_zero_tensors = sorted(set(_gather_zero_tensors(zero_tensors)))

    if dist.get_rank() == 0:
        if gathered_zero_tensors:
            logger.warning(_ZERO_TENSOR_MSG, ", ".join(gathered_zero_tensors))
        else:
            logger.info("All tensors are initialized across all ranks.")

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Expected behavior

Both scripts load the same GPT-OSS model; the second one loads it with accelerate. In the model loaded with accelerate, the experts' up and down projections are initialized to zero.

This occurs because the weight names saved in the checkpoint's state dict differ from the module names in the Hugging Face transformers model class, due to the quantized form in which the model is distributed. The code that remaps these keys is at line 5424 of transformers/modeling_utils.py, and it is not called by accelerate.
