20 changes: 20 additions & 0 deletions src/accelerate/accelerator.py
@@ -1590,6 +1590,8 @@ def _prepare_tp(self, *args):

device_mesh = self.torch_device_mesh

old_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=True))

Comment on lines +1593 to +1594 (Member):
Didn't see this during review, but we shouldn't put FSDP-related code here if possible. If we have to, it should be behind a condition.
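One possible reading of that suggestion, as a minimal sketch: wrap the FSDP2-specific bookkeeping in a guard so it only runs when an FSDP2 plugin is active. The attribute lookups below are illustrative assumptions, not the final implementation; the real condition would use whatever FSDP2 flag the Accelerator already exposes.

    # Illustrative guard (assumed attribute names): only record the old parameter
    # references when an FSDP2 plugin is configured, since the remapping below is
    # only needed for the FSDP2 + TP path.
    fsdp_plugin = getattr(self.state, "fsdp_plugin", None)
    fsdp2_active = fsdp_plugin is not None and getattr(fsdp_plugin, "fsdp_version", 1) == 2
    old_named_params = None
    if fsdp2_active:
        old_named_params = fsdp2_canonicalize_names(
            self._get_named_parameters(*tuple(result), drop_refs=True)
        )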

for arg in result:
if not isinstance(arg, torch.nn.Module):
continue
@@ -1613,6 +1615,24 @@ def _prepare_tp(self, *args):
dp = torch.nn.Parameter(dp, requires_grad=param.requires_grad)
setattr(module_to_tp, param_type, dp)

new_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=False))
# Build a map from old parameter addresses (data_ptr) to the new parameters
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}

def _get_tensor_address(p):
if isinstance(p, DTensor):
return p._local_tensor.data_ptr()
return p.data_ptr()

for obj in result:
if isinstance(obj, torch.optim.Optimizer):
for param_group in obj.param_groups:
# Each param_group originally maps to model parameters (e.g., from model.parameters()).
# After _prepare_tp(), parameter references are replaced with DTensor instances.
# Therefore, we remap the parameter references to their new DTensor addresses
# so that the optimizer can correctly update the model parameters.
param_group["params"] = [mapping[_get_tensor_address(p)] for p in param_group["params"]]
Comment on lines +1618 to +1634 (Member):
Also, we are already modifying the optimizer if FSDPv2 is activated in _prepare_fsdpv2, so we shouldn't modify it here if it is enabled.
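A minimal sketch of that guard, reusing the fsdp2_active flag assumed in the earlier sketch (names are illustrative, not the final implementation):

    # Illustrative guard: skip the remap when FSDP2 is enabled, because
    # _prepare_fsdpv2 already rewires the optimizer's parameter references.
    if not fsdp2_active:
        for obj in result:
            if isinstance(obj, torch.optim.Optimizer):
                for param_group in obj.param_groups:
                    param_group["params"] = [
                        mapping[_get_tensor_address(p)] for p in param_group["params"]
                    ]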


Comment on lines 1618 to 1635 (Member):
Please just add a comment on why we do that here.

Reply (Contributor Author):
Added.

return args

def _prepare_cp(self, *args):
7 changes: 6 additions & 1 deletion src/accelerate/utils/fsdp_utils.py
@@ -470,7 +470,7 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
full_sd (`dict`): The full state dict to load, can only be on rank 0
"""
import torch.distributed as dist
from torch.distributed.tensor import distribute_tensor
from torch.distributed.tensor import DTensor, distribute_tensor

# Model was previously copied to meta device
meta_sharded_sd = model.state_dict()
@@ -506,6 +506,11 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
device_mesh = sharded_param.device_mesh
full_param = full_param.detach().to(device_mesh.device_type)
if isinstance(full_param, DTensor):
# dist.broadcast() only supports torch.Tensor.
# After prepare_tp(), model parameters may become DTensor.
# To broadcast such a parameter, convert it to a local tensor first.
full_param = full_param.to_local()
Comment on lines 509 to 513 (Member):
Here also, can you add a comment?

Reply (Contributor Author):
Added.

dist.broadcast(full_param, src=0, group=dist.group.WORLD)
sharded_tensor = distribute_tensor(full_param, device_mesh, sharded_param.placements)
to_contiguous, casting_dtype = _infer_parameter_dtype(
107 changes: 107 additions & 0 deletions tests/tp/fsdp2_tp_preparation.py
@@ -0,0 +1,107 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import timedelta

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin


class LmHeadWrapper(torch.nn.Module):
def __init__(self, lm_head):
super().__init__()
self.lm_head = lm_head

def forward(self, x):
return self.lm_head(x)


def build_simple_dataloader(tokenizer, seq_len=64, batch_size=2):
"""Build a simple dataloader for reproduction."""
# Load small dataset
raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
raw = raw.filter(lambda x: len(tokenizer(x["text"])["input_ids"]) > 0)
raw = raw.select(range(min(100, len(raw)))) # Use only 100 samples

def tok_fn(examples):
return tokenizer(examples["text"], truncation=True, max_length=seq_len)

ds = raw.map(tok_fn, batched=True, remove_columns=["text"])
ds.set_format(type="torch", columns=["input_ids"])

def collate(batch):
ids = [b["input_ids"] for b in batch]
labels = [x.clone() for x in ids]
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
x = torch.nn.utils.rnn.pad_sequence(ids, batch_first=True, padding_value=pad_id)
y = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
return {"input_ids": x, "labels": y}

return DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate)


def main():
# Configuration
MODEL_NAME = "Qwen/Qwen3-0.6B"
BATCH_SIZE = 2
SEQ_LEN = 64
TP = 2
DP = 4 // TP

# Setup Accelerator with FSDP2
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
pc = ParallelismConfig(dp_shard_size=DP, tp_size=TP)

fsdp_plugin = FullyShardedDataParallelPlugin(
fsdp_version=2,
reshard_after_forward=True,
auto_wrap_policy="transformer_based_wrap",
state_dict_type="SHARDED_STATE_DICT",
activation_checkpointing=False,
cpu_ram_efficient_loading=True,
)

accelerator = Accelerator(kwargs_handlers=[init_kwargs], parallelism_config=pc, fsdp_plugin=fsdp_plugin)

rank = accelerator.process_index
print(f"[Rank {rank}] Initializing...")

# Load model with TP if needed
model_kwargs = {"tp_size": TP, "tp_plan": "auto", "device_mesh": accelerator.torch_device_mesh} if TP > 1 else {}

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, use_cache=False, **model_kwargs)

model.lm_head = LmHeadWrapper(model.lm_head)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

print(f"[Rank {rank}] Building dataloader...")
loader = build_simple_dataloader(tokenizer, seq_len=SEQ_LEN, batch_size=BATCH_SIZE)

print(f"[Rank {rank}] Preparing with accelerator...")
# The error being reproduced occurred here, at accelerator.prepare() (line 110 of the original script)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

print(f"[Rank {rank}] Preparation successful!")


if __name__ == "__main__":
main()
18 changes: 18 additions & 0 deletions tests/tp/fsdp2_tp_preparation_config.yaml
@@ -0,0 +1,18 @@
# FSDP2 Single Node Configuration
# Status: CURRENT - Recommended for new single-node usage

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4 # Adjust for your GPU count
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
16 changes: 16 additions & 0 deletions tests/tp/test_tp.py
@@ -13,6 +13,8 @@
# limitations under the License.


import os

from accelerate.test_utils.testing import (
TempDirTestCase,
execute_subprocess_async,
@@ -61,3 +63,17 @@ def test_working_of_tp(self):
)
with patch_environment(omp_num_threads=1):
execute_subprocess_async(cmd)

def test_working_of_tp_and_fsdp(self):
current_dir = os.path.dirname(os.path.abspath(__file__))
self.test_file_path = os.path.join(current_dir, "fsdp2_tp_preparation.py")
self.test_config_path = os.path.join(current_dir, "fsdp2_tp_preparation_config.yaml")
cmd = get_launch_command()
cmd.extend(
[
f"--config_file={self.test_config_path}",
self.test_file_path,
]
)
with patch_environment(omp_num_threads=4):
execute_subprocess_async(cmd)