[Bug] Update torch.optim.Optimizer parameter states after tensor parallelism #3835
Changes from all commits: 1276fcb, d4f3e0f, 6505e36, 4f08f96, f94f986
```diff
@@ -1590,6 +1590,8 @@ def _prepare_tp(self, *args):
         device_mesh = self.torch_device_mesh

+        old_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=True))
+
         for arg in result:
             if not isinstance(arg, torch.nn.Module):
                 continue
@@ -1613,6 +1615,24 @@ def _prepare_tp(self, *args):
                 dp = torch.nn.Parameter(dp, requires_grad=param.requires_grad)
                 setattr(module_to_tp, param_type, dp)

+        new_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=False))
+        # Build a map from old to new params
+        mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
+
+        def _get_tensor_address(p):
+            if isinstance(p, DTensor):
+                return p._local_tensor.data_ptr()
+            return p.data_ptr()
+
+        for obj in result:
+            if isinstance(obj, torch.optim.Optimizer):
+                for param_group in obj.param_groups:
+                    # Each param_group originally maps to model parameters (e.g., from model.parameters()).
+                    # After _prepare_tp(), parameter references are replaced with DTensor instances.
+                    # Therefore, we remap the parameter references to their new DTensor addresses
+                    # so that the optimizer can correctly update the model parameters.
+                    param_group["params"] = [mapping[_get_tensor_address(p)] for p in param_group["params"]]
+
         return args

     def _prepare_cp(self, *args):
```

Comment on lines +1618 to +1634 (Member): Also, we are already modifying the optimizer if fsdpv2 is activated in `_prepare_fsdpv2`, so we shouldn't modify it here if it is enabled.

Comment on lines 1618 to 1635 (Member): Please just add a comment on why we do that here.

Reply (Contributor, author): Added.
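The remapping above addresses a general PyTorch pitfall: an optimizer built from `model.parameters()` keeps references to the original parameter objects, so if a prepare step later replaces those objects (here, with `DTensor`-backed parameters), the optimizer keeps pointing at tensors the model no longer uses. The sketch below illustrates that pattern outside of Accelerate; the helper name `remap_optimizer_params` and the use of `id()` as the lookup key are illustrative choices, whereas the PR keys its map by data pointer.

```python
import torch


def remap_optimizer_params(optimizer, old_to_new):
    """Point an existing optimizer at replacement parameter objects.

    `old_to_new` maps id(old_param) -> new_param.
    """
    for group in optimizer.param_groups:
        group["params"] = [old_to_new.get(id(p), p) for p in group["params"]]


model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Simulate a prepare step that swaps every parameter object for a new one.
old_to_new = {}
for name, param in list(model.named_parameters()):
    new_param = torch.nn.Parameter(param.detach().clone(), requires_grad=param.requires_grad)
    old_to_new[id(param)] = new_param
    parent_name, _, attr = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, attr, new_param)

# Without this call, optimizer.step() would act on the discarded tensors
# and leave the parameters the model actually uses untouched.
remap_optimizer_params(optimizer, old_to_new)

loss = model(torch.randn(2, 4)).sum()
loss.backward()
optimizer.step()
```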
```diff
@@ -470,7 +470,7 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict
         full_sd (`dict`): The full state dict to load, can only be on rank 0
     """
     import torch.distributed as dist
-    from torch.distributed.tensor import distribute_tensor
+    from torch.distributed.tensor import DTensor, distribute_tensor

     # Model was previously copied to meta device
     meta_sharded_sd = model.state_dict()
@@ -506,6 +506,11 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
     for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
         device_mesh = sharded_param.device_mesh
         full_param = full_param.detach().to(device_mesh.device_type)
+        if isinstance(full_param, DTensor):
+            # dist.broadcast() only supports torch.Tensor.
+            # After prepare_tp(), model parameters may become DTensor.
+            # To broadcast such a parameter, convert it to a local tensor first.
+            full_param = full_param.to_local()
         dist.broadcast(full_param, src=0, group=dist.group.WORLD)
         sharded_tensor = distribute_tensor(full_param, device_mesh, sharded_param.placements)
         to_contiguous, casting_dtype = _infer_parameter_dtype(
```

Comment on lines 509 to 513 (Member): Here also, can you add a comment?

Reply (Contributor, author): Added.
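For context on the `to_local()` conversion (a standalone sketch, not Accelerate code): `torch.distributed.broadcast` only accepts plain `torch.Tensor`s, so a `DTensor` parameter has to be unwrapped to its local shard before it can take part in the collective. The script below assumes a CPU/gloo setup launched with `torchrun --nproc_per_node=2`; the file name and mesh shape are arbitrary.

```python
# Hypothetical standalone example; run with: torchrun --nproc_per_node=2 broadcast_dtensor.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, distribute_tensor


def broadcast_param(param: torch.Tensor) -> torch.Tensor:
    # dist.broadcast() only supports plain tensors, so unwrap a DTensor first.
    if isinstance(param, DTensor):
        param = param.to_local()
    dist.broadcast(param, src=0, group=dist.group.WORLD)
    return param


def main():
    dist.init_process_group(backend="gloo")  # assumption: CPU backend for simplicity
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))

    # A replicated DTensor standing in for a TP-prepared parameter.
    param = distribute_tensor(torch.randn(4, 4), mesh, [Replicate()])

    local = broadcast_param(param)
    print(f"rank {dist.get_rank()}: broadcast a local shard of shape {tuple(local.shape)}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```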
New file (107-line reproduction script):

```python
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import timedelta

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin


class LmHeadWrapper(torch.nn.Module):
    def __init__(self, lm_head):
        super().__init__()
        self.lm_head = lm_head

    def forward(self, x):
        return self.lm_head(x)


def build_simple_dataloader(tokenizer, seq_len=64, batch_size=2):
    """Build a simple dataloader for reproduction."""
    # Load small dataset
    raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
    raw = raw.filter(lambda x: len(tokenizer(x["text"])["input_ids"]) > 0)
    raw = raw.select(range(min(100, len(raw))))  # Use only 100 samples

    def tok_fn(examples):
        return tokenizer(examples["text"], truncation=True, max_length=seq_len)

    ds = raw.map(tok_fn, batched=True, remove_columns=["text"])
    ds.set_format(type="torch", columns=["input_ids"])

    def collate(batch):
        ids = [b["input_ids"] for b in batch]
        labels = [x.clone() for x in ids]
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        x = torch.nn.utils.rnn.pad_sequence(ids, batch_first=True, padding_value=pad_id)
        y = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
        return {"input_ids": x, "labels": y}

    return DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate)


def main():
    # Configuration
    MODEL_NAME = "Qwen/Qwen3-0.6B"
    BATCH_SIZE = 2
    SEQ_LEN = 64
    TP = 2
    DP = 4 // TP

    # Setup Accelerator with FSDP2
    init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
    pc = ParallelismConfig(dp_shard_size=DP, tp_size=TP)

    fsdp_plugin = FullyShardedDataParallelPlugin(
        fsdp_version=2,
        reshard_after_forward=True,
        auto_wrap_policy="transformer_based_wrap",
        state_dict_type="SHARDED_STATE_DICT",
        activation_checkpointing=False,
        cpu_ram_efficient_loading=True,
    )

    accelerator = Accelerator(kwargs_handlers=[init_kwargs], parallelism_config=pc, fsdp_plugin=fsdp_plugin)

    rank = accelerator.process_index
    print(f"[Rank {rank}] Initializing...")

    # Load model with TP if needed
    model_kwargs = {"tp_size": TP, "tp_plan": "auto", "device_mesh": accelerator.torch_device_mesh} if TP > 1 else {}

    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, use_cache=False, **model_kwargs)

    model.lm_head = LmHeadWrapper(model.lm_head)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    print(f"[Rank {rank}] Building dataloader...")
    loader = build_simple_dataloader(tokenizer, seq_len=SEQ_LEN, batch_size=BATCH_SIZE)

    print(f"[Rank {rank}] Preparing with accelerator...")
    # ERROR OCCURS HERE AT LINE 110 in original script
    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

    print(f"[Rank {rank}] Preparation successful!")


if __name__ == "__main__":
    main()
```
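A quick sanity check (not part of the PR; an illustrative helper) for whether the stale-optimizer problem is present is to verify, after `accelerator.prepare(...)`, that every parameter object held by the optimizer is still one the model owns:

```python
import torch


def optimizer_tracks_model(model: torch.nn.Module, optimizer: torch.optim.Optimizer) -> bool:
    """Return True if every optimizer parameter is an object the model still owns."""
    model_param_ids = {id(p) for p in model.parameters()}
    return all(
        id(p) in model_param_ids
        for group in optimizer.param_groups
        for p in group["params"]
    )


# For example, right after `model, optimizer, loader = accelerator.prepare(model, optimizer, loader)`:
# assert optimizer_tracks_model(model, optimizer), "optimizer still references pre-TP parameters"
```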
New file (FSDP2 single-node `accelerate` config):

```yaml
# FSDP2 Single Node Configuration
# Status: CURRENT - Recommended for new single-node usage

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4  # Adjust for your GPU count
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
Comment (Member): Didn't see this during review, but we shouldn't put FSDP-related code here if possible. If we have to, it should be in a condition.