[Bug] prepare() silently double-wraps models and double-serializes checkpoints when called twice on the same object #3967
Description
System Info
- `Accelerate` version: 1.12.0
- Platform: Linux-6.6.113+-x86_64-with-glibc2.35
- `accelerate` bash location: /usr/local/bin/accelerate
- Python version: 3.12.12
- Numpy version: 2.0.2
- PyTorch version: 2.9.0+cu126
- PyTorch accelerator: CUDA
- System RAM: 31.35 GB
- GPU type: Tesla T4
- `Accelerate` default config:
Not found

Information
- The official example scripts
- My own modified scripts
Tasks
- One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
- My own task or dataset (give details below)
Reproduction
Steps to reproduce the behavior:
- The minimal snippet that exposes the registry corruption:
from accelerate import Accelerator
import torch.nn as nn
accelerator = Accelerator()
model = nn.Linear(10, 2)
model = accelerator.prepare(model)  # first prepare
print(type(model))                  # <class 'DistributedDataParallel'>
print(len(accelerator._models))     # 1
model = accelerator.prepare(model)  # second prepare — no error, no warning
print(type(model))                  # still DistributedDataParallel — looks fine
print(len(accelerator._models))     # 2 ← silent corruption
- The full reproduction script, with gradient hooks, checkpoint save/load, and a DOUBLE_PREPARE toggle, is here: https://gist.github.com/iavinas/283b8e3fda92bf94c96a7211d0d3720e
- Launch with: accelerate launch --num_processes=2 accelerate_debug_prepare_twice.py
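The corruption is easy to reason about even without GPUs: DDP keeps the wrapped model as `.module` and prefixes every state_dict key with `module.`, so nesting two wrappers yields `module.module.` keys. A minimal pure-Python mimic of that one behavior (the `Wrapper` and `TinyLinear` classes below are hypothetical stand-ins, not torch or Accelerate classes):

```python
class TinyLinear:
    """Hypothetical stand-in for nn.Linear(10, 2); only state_dict matters here."""
    def state_dict(self):
        return {"weight": [[0.0] * 10 for _ in range(2)], "bias": [0.0, 0.0]}

class Wrapper:
    """Mimics the one DDP behavior relevant to this bug: it keeps the inner
    model as `self.module` and prefixes every state_dict key with 'module.'."""
    def __init__(self, module):
        self.module = module
    def state_dict(self):
        return {f"module.{k}": v for k, v in self.module.state_dict().items()}

once = Wrapper(TinyLinear())      # what one prepare() produces
twice = Wrapper(once)             # what a second prepare() produces

print(list(once.state_dict()))    # ['module.weight', 'module.bias']
print(list(twice.state_dict()))   # ['module.module.weight', 'module.module.bias']
```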
Expected behavior
--- DOUBLE_PREPARE = False (correct baseline) ---
Registry:
TOTAL MODELS: 1
model id: 140633038895728
Module structure (unwrap once to reach original):
DDP → Linear(10, 2)
state_dict keys (correct prefix):
module.weight
module.bias
Checkpoint calls per rank per save_state():
[Rank 0] STATE_DICT CALLED ×1
[Rank 1] STATE_DICT CALLED ×1
Checkpoint calls per rank per load_state():
[Rank 0] STATE_DICT CALLED ×1
[Rank 1] STATE_DICT CALLED ×1
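The per-rank call counts above come from instrumenting state_dict in the gist; the same idea can be sketched generically (`FakeModel` and `count_state_dict_calls` below are hypothetical names, not the gist's exact code):

```python
class FakeModel:
    """Hypothetical stand-in for a prepared model."""
    def state_dict(self):
        return {"module.weight": [0.0], "module.bias": [0.0]}

def count_state_dict_calls(obj):
    """Replace obj.state_dict with a counting wrapper and return the counter,
    so each save_state()/load_state() reveals how many times it hits the model."""
    counter = {"calls": 0}
    original = obj.state_dict
    def counting_state_dict(*args, **kwargs):
        counter["calls"] += 1
        return original(*args, **kwargs)
    obj.state_dict = counting_state_dict
    return counter

model = FakeModel()
calls = count_state_dict_calls(model)
model.state_dict()
print(calls["calls"])  # 1 after a single save
```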
--- DOUBLE_PREPARE = True (buggy) ---
Registry:
TOTAL MODELS: 2
model id: 136223388433872 ← DDP(model) — entry from 1st prepare
model id: 136222951175856 ← DDP(DDP(model)) — entry from 2nd prepare, NEW object
The two distinct IDs confirm a real nested wrapper was constructed,
not a duplicate reference to the same object.
Module structure (unwrap once — stops at inner DDP):
DDP → DDP → Linear(10, 2)
state_dict keys (corrupted prefix):
module.module.weight
module.module.bias
Any correctly-prepared run that loads a checkpoint saved by a
double-prepared run will fail with a key mismatch:
RuntimeError: Error(s) in loading state_dict:
Missing key(s): "module.weight", "module.bias"
Unexpected key(s): "module.module.weight", "module.module.bias"
Checkpoint calls per rank per save_state():
[Rank 0] STATE_DICT CALLED ×2 ← doubled
[Rank 1] STATE_DICT CALLED ×2
Checkpoint calls per rank per load_state():
[Rank 0] STATE_DICT CALLED ×2 ← doubled
[Rank 1] STATE_DICT CALLED ×2
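For checkpoints that were already written by a double-prepared run, the extra prefix can be stripped before loading. A recovery sketch (`strip_extra_module_prefix` is a hypothetical helper, not an Accelerate API):

```python
def strip_extra_module_prefix(state_dict, prefix="module."):
    """Collapse repeated 'module.' prefixes ('module.module.weight' ->
    'module.weight') so a checkpoint saved by a double-prepared run can be
    loaded into a correctly-prepared (singly wrapped) model."""
    fixed = {}
    for key, value in state_dict.items():
        # Strip one leading prefix at a time while the key is still doubled.
        while key.startswith(prefix * 2):
            key = key[len(prefix):]
        fixed[key] = value
    return fixed

corrupted = {"module.module.weight": 1, "module.module.bias": 2}
print(strip_extra_module_prefix(corrupted))  # {'module.weight': 1, 'module.bias': 2}
```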
--- Why this is dangerous ---
Training losses are numerically identical between the two runs.
The outer DDP wrapper's all-reduce is a no-op because the inner wrapper
has already synchronized gradients. There is no crash, no NaN, no visible
training divergence. The only symptoms are:
- _models registry length grows to 2
- state_dict key prefix corrupted: "module.module." instead of "module."
- checkpoint save/load cost doubles for large models
- cross-run checkpoint compatibility silently broken
--- Expected behavior ---
prepare() should detect that the model argument was already registered
in _models — either as the exact prepared object or as the unwrapped
inner module of an existing entry — emit a UserWarning explaining the
consequences, and return the model unchanged without appending a second
entry to _models.
The identity check must use `is` (object identity), not `==` (structural
equality), so that legitimate multi-model setups such as knowledge
distillation, where two distinct model objects share the same architecture,
are not affected.
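The proposed check can be sketched in a few lines; this is illustrative only (`guarded_register` and the plain-list registry are hypothetical, not Accelerate internals):

```python
import warnings

def guarded_register(registry, model):
    """Return the already-tracked object unchanged if `model` is either an
    entry in the registry or the unwrapped `.module` of one; otherwise
    register it. The wrapping step itself is elided in this sketch."""
    for tracked in registry:
        # `is` (object identity), not `==`: two distinct models with the same
        # architecture (e.g. teacher/student distillation) must both register.
        if model is tracked or model is getattr(tracked, "module", None):
            warnings.warn(
                "model was already prepared; returning the existing object "
                "unchanged to avoid nested wrapping and doubled checkpoints",
                UserWarning,
            )
            return tracked
    registry.append(model)
    return model
```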