[Bug] prepare() silently double-wraps models and double-serializes checkpoints when called twice on the same object #3967

@iavinas

Description

System Info

- `Accelerate` version: 1.12.0
- Platform: Linux-6.6.113+-x86_64-with-glibc2.35
- `accelerate` bash location: /usr/local/bin/accelerate
- Python version: 3.12.12
- Numpy version: 2.0.2
- PyTorch version: 2.9.0+cu126
- PyTorch accelerator: CUDA
- System RAM: 31.35 GB
- GPU type: Tesla T4
- `Accelerate` default config:
	Not found

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

Steps to reproduce the behavior:

  1. The minimal snippet that exposes the registry corruption:

    from accelerate import Accelerator
    import torch.nn as nn

    accelerator = Accelerator()
    model = nn.Linear(10, 2)

    model = accelerator.prepare(model)           # first prepare
    print(type(model))                           # <class 'DistributedDataParallel'>
    print(len(accelerator._models))              # 1

    model = accelerator.prepare(model)           # second prepare — no error, no warning
    print(type(model))                           # still DistributedDataParallel — looks fine
    print(len(accelerator._models))              # 2  ← silent corruption

  2. The full reproduction script (with gradient hooks, checkpoint save/load,
     and a DOUBLE_PREPARE toggle) is here: https://gist.github.com/iavinas/283b8e3fda92bf94c96a7211d0d3720e

  3. Launch with: accelerate launch --num_processes=2 accelerate_debug_prepare_twice.py
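Until a guard lands in Accelerate itself, a user-side workaround is possible. A minimal sketch (note: `_models` is a private Accelerator attribute that may change between versions, and `safe_prepare` is a hypothetical helper, not part of the Accelerate API):

```python
def safe_prepare(accelerator, model):
    """Call accelerator.prepare(model) only if this exact object has not
    already been registered; otherwise return it unchanged."""
    for registered in accelerator._models:
        # identity (`is`), not equality: only the very same object counts,
        # whether as the prepared wrapper or as its unwrapped inner module
        if registered is model or getattr(registered, "module", None) is model:
            return model
    return accelerator.prepare(model)
```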

Expected behavior

--- DOUBLE_PREPARE = False (correct baseline) ---

Registry:
TOTAL MODELS: 1
model id: 140633038895728

Module structure (unwrap once to reach original):
DDP → Linear(10, 2)

state_dict keys (correct prefix):
module.weight
module.bias

Checkpoint calls per rank per save_state():
[Rank 0] STATE_DICT CALLED ×1
[Rank 1] STATE_DICT CALLED ×1

Checkpoint calls per rank per load_state():
[Rank 0] STATE_DICT CALLED ×1
[Rank 1] STATE_DICT CALLED ×1

--- DOUBLE_PREPARE = True (buggy) ---

Registry:
TOTAL MODELS: 2
model id: 136223388433872 ← DDP(model) — entry from 1st prepare
model id: 136222951175856 ← DDP(DDP(model)) — entry from 2nd prepare, NEW object

The two distinct IDs confirm a real nested wrapper was constructed,
not a duplicate reference to the same object.

Module structure (unwrap once — stops at inner DDP):
DDP → DDP → Linear(10, 2)

state_dict keys (corrupted prefix):
module.module.weight
module.module.bias

Any correctly-prepared run that loads a checkpoint saved by the
double-prepared run will fail with a key mismatch:
RuntimeError: Error(s) in loading state_dict:
Missing key(s): "module.weight", "module.bias"
Unexpected key(s): "module.module.weight", "module.module.bias"

Checkpoint calls per rank per save_state():
[Rank 0] STATE_DICT CALLED ×2 ← doubled
[Rank 1] STATE_DICT CALLED ×2

Checkpoint calls per rank per load_state():
[Rank 0] STATE_DICT CALLED ×2 ← doubled
[Rank 1] STATE_DICT CALLED ×2

--- Why this is dangerous ---

Training losses are numerically identical between the two runs.
The outer DDP wrapper's all-reduce is a no-op because the inner wrapper
has already synchronized gradients. There is no crash, no NaN, no visible
training divergence. The only symptoms are:

  • _models registry length grows to 2
  • state_dict key prefix corrupted: "module.module." instead of "module."
  • checkpoint save/load cost doubles for large models
  • cross-run checkpoint compatibility silently broken
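Because none of these symptoms surface during training, a post-prepare assertion can catch the corruption early. A sketch (hypothetical helper; it only looks for two nested wrappers of the same type):

```python
def assert_not_double_wrapped(model):
    """Raise if `model` looks like Wrapper(Wrapper(...)), e.g. DDP(DDP(m))."""
    inner = getattr(model, "module", None)
    if inner is not None and type(inner) is type(model):
        raise RuntimeError(
            "model appears to be wrapped twice (nested wrappers of the "
            "same type); was prepare() called more than once?")
```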

--- Expected behavior ---

prepare() should detect that the model argument was already registered
in _models — either as the exact prepared object or as the unwrapped
inner module of an existing entry — emit a UserWarning explaining the
consequences, and return the model unchanged without appending a second
entry to _models.

The identity check must use is (object identity) not == (structural
equality) so that legitimate multi-model setups like knowledge distillation,
where two distinct model objects share the same architecture, are not
affected.
