[Bug] Update torch.optim.Optimizer parameter states after tensor parallelism#3835
Conversation
2. Fix DTensor broadcast issues in cpu_ram_efficient_loading
Hello @naomili0924, thanks for helping me out. With the reproduction setup I shared in #3820, I see no runtime error. This is good news.
That makes sense: if TP=1, tensor parallelism doesn't happen and full_param is a plain Tensor, not a DTensor (which is why to_local() fails).
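For illustration, a minimal sketch of the guard being discussed (the DTensor import path is an assumption and varies across PyTorch versions, and to_local_if_dtensor is a hypothetical helper name, not part of this PR):

```python
import torch
from torch.distributed.tensor import DTensor  # assumption: recent PyTorch; older releases expose it under torch.distributed._tensor

def to_local_if_dtensor(t: torch.Tensor) -> torch.Tensor:
    # With TP=1 the parameter stays a plain Tensor, which has no to_local(),
    # so guard the call instead of assuming a DTensor.
    return t.to_local() if isinstance(t, DTensor) else t
```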
Force-pushed 33f11e1 to d4f3e0f
@SunMarc @zach-huggingface Please take a look at this small bug fix and feel free to leave any feedback. Thanks!
new_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=False))
# Build a map from old to new params
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}

def _get_tensor_address(p):
    if isinstance(p, DTensor):
        return p._local_tensor.data_ptr()
    return p.data_ptr()

for obj in result:
    if isinstance(obj, torch.optim.Optimizer):
        for param_group in obj.param_groups:
            param_group["params"] = [mapping[_get_tensor_address(p)] for p in param_group["params"]]
Please just add a comment on why we do that here.
if isinstance(full_param, DTensor):
    full_param = full_param.to_local()
Here also, can you add a comment?
Force-pushed c58f8c3 to 6505e36
#3820: use this issue as the test case to guard this bug fix. Thanks!
SunMarc left a comment
Thanks a lot for fixing those!
Hi @SunMarc, just checking in: do I need to do anything at this point to get this merged? Thanks!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed ab4f7a2 to f94f986
Hi @SunMarc, sorry to take up your time. I'm confused about those two failing test runs.
Not sure why the CI is struggling. I will merge it!
Can you check why?
stderr: [rank0]: Traceback (most recent call last):
stderr: [rank0]: File "/__w/accelerate/accelerate/accelerate/src/accelerate/test_utils/scripts/external_deps/test_performance.py", line 299, in <module>
stderr: [rank0]: main()
stderr: [rank0]: File "/__w/accelerate/accelerate/accelerate/src/accelerate/test_utils/scripts/external_deps/test_performance.py", line 295, in main
stderr: [rank0]: training_function(config, args)
stderr: [rank0]: File "/__w/accelerate/accelerate/accelerate/src/accelerate/test_utils/scripts/external_deps/test_performance.py", line 145, in training_function
stderr: [rank0]: model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
stderr: [rank0]: File "/__w/accelerate/accelerate/accelerate/src/accelerate/accelerator.py", line 1541, in prepare
stderr: [rank0]: args = self._prepare_tp(*args)
stderr: [rank0]: File "/__w/accelerate/accelerate/accelerate/src/accelerate/accelerator.py", line 1634, in _prepare_tp
stderr: [rank0]: param_group["params"] = [mapping[_get_tensor_address(p)] for p in param_group["params"]]
stderr: [rank0]: File "/__w/accelerate/accelerate/accelerate/src/accelerate/accelerator.py", line 1634, in <listcomp>
stderr: [rank0]: param_group["params"] = [mapping[_get_tensor_address(p)] for p in param_group["params"]]
stderr: [rank0]: KeyError: 139761680384000
I resolved this failing test case in #3845. Please take a look, thanks!
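As a hypothetical illustration of this failure mode (not necessarily the root cause of the CI failure above, which is addressed in #3845): if the optimizer holds a parameter that the model's named-parameter map does not contain, the data_ptr lookup raises exactly this kind of KeyError.

```python
import torch

model = torch.nn.Linear(2, 2)
# A parameter the optimizer knows about but the model's name map does not
# (e.g. a tied weight that named_parameters() deduplicated away).
extra = torch.nn.Parameter(torch.zeros(2))
opt = torch.optim.SGD(list(model.parameters()) + [extra], lr=0.1)

mapping = {p.data_ptr(): p for _, p in model.named_parameters()}
try:
    remapped = [mapping[p.data_ptr()] for g in opt.param_groups for p in g["params"]]
except KeyError as e:
    print("KeyError:", e)  # a raw data_ptr, like the one in the traceback
```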
old_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=True))
Didn't see this during review, but we shouldn't put FSDP-related code here if possible. If we have to, it should be in a condition.
new_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=False))
# Build a map from old to new params
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}

def _get_tensor_address(p):
    if isinstance(p, DTensor):
        return p._local_tensor.data_ptr()
    return p.data_ptr()

for obj in result:
    if isinstance(obj, torch.optim.Optimizer):
        for param_group in obj.param_groups:
            # Each param_group originally maps to model parameters (e.g., from model.parameters()).
            # After _prepare_tp(), parameter references are replaced with DTensor instances.
            # Therefore, we remap the parameter references to their new DTensor addresses
            # so that the optimizer can correctly update the model parameters.
            param_group["params"] = [mapping[_get_tensor_address(p)] for p in param_group["params"]]
Also, we are already modifying the optimizer when FSDPv2 is activated in _prepare_fsdpv2, so we shouldn't modify it here if that is enabled.
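To make the intent of the diff above concrete, here is a minimal single-process sketch of the same remapping idea. TP sharding is simulated by simply replacing the parameter tensors with clones (plain tensors instead of DTensor shards), and the real _get_named_parameters / fsdp2_canonicalize_names helpers are not used:

```python
import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Record addresses of the original parameters before they are replaced.
old_named = {name: p.data_ptr() for name, p in model.named_parameters()}

# Simulate preparation: every parameter object is swapped out
# (tensor parallelism would replace them with DTensor shards instead).
for name in ("weight", "bias"):
    old = getattr(model, name)
    setattr(model, name, torch.nn.Parameter(old.detach().clone()))

# Map old addresses to the new parameter objects, then point the optimizer at them.
new_named = dict(model.named_parameters())
mapping = {addr: new_named[name] for name, addr in old_named.items()}
for group in opt.param_groups:
    group["params"] = [mapping[p.data_ptr()] for p in group["params"]]

assert opt.param_groups[0]["params"][0] is model.weight  # optimizer now tracks the new params
```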


This PR adds two functions:
This PR aims to fix these issues: #3820 and #3821.
=========================================================================
Below is the reproduction process:
Reproduced this issue based on the original YAML and Python files given in the issue description:
accelerate launch --config_file hessian_toolkit/configs/fsdp2_single_node.yaml minimal_reproduce_accelerate_bug.py
However, to fully fix this issue we need two parts:
Why we need this change:
In AutoModelForCausalLM, the model’s lm_head (an nn.Linear layer with shape [embedding_size, vocab_size]) is automatically added during from_pretrained.
By design, lm_head shares the same weight reference as embed_tokens due to the symmetric input–output projection setup.
Before applying Tensor Parallelism (TP), model.named_parameters() only includes embed_tokens and the backbone parameters — the lm_head is not listed separately.
However, when TP is applied, it explicitly separates embed_tokens and lm_head by assigning them independent DTensor shards. This effectively modifies the model architecture by introducing a distinct lm_head module.
Without wrapping the model with an LmHeadWrapper, a KeyError occurs because the post-TP model’s architecture no longer matches the original state dictionary structure.
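A small toy model (not the actual reproduction code) illustrates the weight-tying behaviour described above: because lm_head.weight and embed_tokens.weight are the same tensor, named_parameters() deduplicates the shared weight and lm_head never shows up by name.

```python
import torch.nn as nn

class TinyLM(nn.Module):
    # Toy stand-in for a causal LM with tied input/output embeddings.
    def __init__(self, vocab_size=10, hidden_size=4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.lm_head.weight = self.embed_tokens.weight  # weight tying

model = TinyLM()
# named_parameters() deduplicates shared tensors, so only embed_tokens.weight is listed:
print([name for name, _ in model.named_parameters()])  # ['embed_tokens.weight']
```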
With these two changes, I ran the script and got:

What does this PR do?
Fixes #3820 and #3821
Before submitting
Did you read the contributor guideline, Pull Request section?
Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.