
[Bug] Update torch.optim.Optimizer parameter states after tensor parallelism#3835

Merged
SunMarc merged 5 commits intohuggingface:mainfrom
naomili0924:fix_tp_and_fsdp_optimizer
Nov 19, 2025

Conversation

@naomili0924
Contributor

@naomili0924 naomili0924 commented Nov 9, 2025

This PR adds two fixes:

  1. Update the torch.optim.Optimizer parameter references in self._prepare_tp(), similar to what fsdp2_switch_optimizer_parameters does.
  2. In the FSDP stage, if model parameters are DTensors (because they previously went through tensor parallelism), convert each DTensor to a plain Tensor before dist.broadcast().

This PR aims to fix these issues:
#3820
#3821

=========================================================================
Below is the reproduction process:

I reproduced this issue based on the original YAML and Python file given in the issue description:
accelerate launch --config_file hessian_toolkit/configs/fsdp2_single_node.yaml minimal_reproduce_accelerate_bug.py
However, to fully fix this issue we need two changes:

  1. Wrap lm_head as a separate, standalone module. Modify the Python file as follows:
import torch
from transformers import AutoModelForCausalLM


class LmHeadWrapper(torch.nn.Module):
    def __init__(self, lm_head):
        super().__init__()
        self.lm_head = lm_head

    def forward(self, x):
        return self.lm_head(x)


model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    use_cache=False,
    **model_kwargs,
)

model.lm_head = LmHeadWrapper(model.lm_head)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

Why we need this change:

In AutoModelForCausalLM, the model's lm_head (an nn.Linear layer projecting embedding_size to vocab_size) is automatically added during from_pretrained.
By design, lm_head shares the same weight reference as embed_tokens because the input and output projections are tied.

Before applying Tensor Parallelism (TP), model.named_parameters() only includes embed_tokens and the backbone parameters — the lm_head is not listed separately.

However, when TP is applied, it explicitly separates embed_tokens and lm_head by assigning them independent DTensor shards. This effectively modifies the model architecture by introducing a distinct lm_head module.

Without wrapping the model with an LmHeadWrapper, a KeyError occurs because the post-TP model’s architecture no longer matches the original state dictionary structure.
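
As a side note (not part of the PR), the tied-weight behavior described above can be checked directly. A minimal sketch, assuming "gpt2" as an example checkpoint; the exact embedding submodule path varies by architecture, so get_input_embeddings() is used instead of a hard-coded attribute:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# lm_head shares the same tensor object as the input embedding (weight tying).
print(model.lm_head.weight is model.get_input_embeddings().weight)  # True

# Because of the tie, named_parameters() de-duplicates the shared tensor,
# so no separate "lm_head.weight" entry appears before TP is applied.
param_names = [name for name, _ in model.named_parameters()]
print(any("lm_head" in name for name in param_names))  # False for tied checkpoints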

  2. After TP, the model's parameters are replaced with DTensors, so the torch.optim.Optimizer param_groups must be redirected to the new DTensor references; otherwise fsdp2_switch_optimizer_parameters throws a KeyError for the old addresses (see the sketch below).
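
To illustrate the stale-reference problem in isolation, here is a minimal single-process sketch (no TP involved); it keys the mapping on id() purely for simplicity, whereas the actual change keys on data_ptr() (and _local_tensor.data_ptr() for DTensors):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Simulate what TP sharding does: replace a parameter with a new object.
old_weight = model.weight
model.weight = nn.Parameter(old_weight.detach().clone())

# The optimizer still references the old tensor object.
print(optimizer.param_groups[0]["params"][0] is model.weight)  # False
print(optimizer.param_groups[0]["params"][0] is old_weight)    # True

# Remap by parameter name, the same idea the PR applies after _prepare_tp().
old_named = {"weight": old_weight, "bias": model.bias}
new_named = dict(model.named_parameters())
mapping = {id(p): new_named[n] for n, p in old_named.items()}
for group in optimizer.param_groups:
    group["params"] = [mapping[id(p)] for p in group["params"]]
print(optimizer.param_groups[0]["params"][0] is model.weight)  # True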

With these two changes, I ran the script and it completed successfully.

What does this PR do?

Fixes #3820
Fixes #3821

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

2. Fix DTensor broadcast issues in cpu_ram_efficient_loading
@naomili0924 naomili0924 changed the title Fix Torch.Optim.Optimizer param_group in TP+FSDP [Bug] Fix param_group in torch.optim.Optimizer for TP+FSDP Nov 9, 2025
@juraev

juraev commented Nov 9, 2025

Hello @naomili0924

Thanks for helping me out.
I have installed accelerate from your fork/branch.

With the reproduction setup I shared in #3820, I see no runtime error. This is good news.
However, I noticed it hung and did not finish the forward pass.
Moreover, when I just set TP=1 in that code, I see the following runtime error:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/Scaled-Lanczos/tmp/bug.py", line 109, in <module>
[rank0]:     main()
[rank0]:   File "/workspace/Scaled-Lanczos/tmp/bug.py", line 96, in main
[rank0]:     model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
[rank0]:                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/accelerate/src/accelerate/accelerator.py", line 1555, in prepare
[rank0]:     result = self._prepare_fsdp2(*args)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/accelerate/src/accelerate/accelerator.py", line 1703, in _prepare_fsdp2
[rank0]:     model = fsdp2_prepare_model(self, model)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/accelerate/src/accelerate/utils/fsdp_utils.py", line 679, in fsdp2_prepare_model
[rank0]:     fsdp2_load_full_state_dict(accelerator, model, original_sd)
[rank0]:   File "/workspace/accelerate/src/accelerate/utils/fsdp_utils.py", line 509, in fsdp2_load_full_state_dict
[rank0]:     full_param = full_param.to_local()
[rank0]:                  ^^^^^^^^^^^^^^^^^^^
[rank0]: AttributeError: 'Tensor' object has no attribute 'to_local'

@naomili0924
Contributor Author

naomili0924 commented Nov 9, 2025

That makes sense: if TP=1, tensor parallelism does not happen, so full_param is a plain Tensor rather than a DTensor, which is why to_local() fails.
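
For clarity, the fix amounts to a type guard along these lines; to_broadcastable is a hypothetical helper name used only for this sketch (the actual change guards the call inline, as quoted in the review comments below), and the DTensor import path can differ across PyTorch versions:

import torch
from torch.distributed.tensor import DTensor

def to_broadcastable(full_param: torch.Tensor) -> torch.Tensor:
    # DTensors only exist when TP actually sharded the model; with TP=1 the
    # parameter is a plain Tensor and is passed through unchanged.
    if isinstance(full_param, DTensor):
        return full_param.to_local()
    return full_param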

@naomili0924
Contributor Author


I added another commit to avoid calling Tensor.to_local() when TP doesn't happen.
However, to fully run your code, please remember to comment out
model.lm_head = LmHeadWrapper(model.lm_head)
because when TP=1, _prepare_tp won't be used (a conditional version is sketched below).
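
A hypothetical convenience for the reproduction script (not part of the PR): wrap lm_head only when tensor parallelism is actually in use, so the same script works for TP=1 and TP>1. Reading the degree from a TP_SIZE environment variable is purely an assumption for illustration; model and LmHeadWrapper are the ones defined earlier in this thread.

import os

tp_size = int(os.environ.get("TP_SIZE", "1"))  # hypothetical way to obtain the TP degree
if tp_size > 1:
    # Only wrap when tensor parallelism will actually reshape the parameters.
    model.lm_head = LmHeadWrapper(model.lm_head)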

@naomili0924 naomili0924 force-pushed the fix_tp_and_fsdp_optimizer branch from 33f11e1 to d4f3e0f Compare November 10, 2025 00:22
@naomili0924 naomili0924 changed the title [Bug] Fix param_group in torch.optim.Optimizer for TP+FSDP [Bug] Update torch.optim.Optimizer parameter states after tensor parallelism Nov 10, 2025
@naomili0924
Contributor Author

@SunMarc @zach-huggingface Please take a look at this small bug fix and feel free to leave any feedback. Thanks!

Member

@SunMarc SunMarc left a comment


Thanks a lot, really nice report! It would be great if you could add a simple test for each case.

Comment on lines 1618 to 1631
new_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=False))
# Build a map from old to new params
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}

def _get_tensor_address(p):
    if isinstance(p, DTensor):
        return p._local_tensor.data_ptr()
    return p.data_ptr()

for obj in result:
    if isinstance(obj, torch.optim.Optimizer):
        for param_group in obj.param_groups:
            param_group["params"] = [mapping[_get_tensor_address(p)] for p in param_group["params"]]

Member


please just add a comment on why we do that here

Contributor Author


added.

Comment on lines 509 to 510
if isinstance(full_param, DTensor):
    full_param = full_param.to_local()
Member


Here also, can you add a comment?

Contributor Author


added.

@naomili0924 naomili0924 force-pushed the fix_tp_and_fsdp_optimizer branch from c58f8c3 to 6505e36 Compare November 14, 2025 02:47
@naomili0924
Contributor Author

Thanks a lot, really nice report! It would be great if you could add a simple test for each case.

I used the reproduction from #3820 as the test case to guard this bug fix. Thanks!

@naomili0924 naomili0924 requested a review from SunMarc November 14, 2025 02:59
Member

@SunMarc SunMarc left a comment


Thanks a lot for fixing those!

@naomili0924
Contributor Author

@SunMarc Hi, just checking in: do I need to do anything at this time to get this merged? Thanks!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@naomili0924 naomili0924 force-pushed the fix_tp_and_fsdp_optimizer branch from ab4f7a2 to f94f986 Compare November 18, 2025 05:44
@naomili0924
Contributor Author

@SunMarc Hi, sorry to take your time. I'm confused about those two failing test runs.
[screenshot of the two failing CI checks]
I ran those tests locally and they pass. Thanks for your help.

@SunMarc
Member

SunMarc commented Nov 19, 2025

Not sure why the CI is struggling. I will merge it!

@SunMarc SunMarc merged commit 00b1b18 into huggingface:main Nov 19, 2025
23 of 25 checks passed
@SunMarc
Member

SunMarc commented Nov 21, 2025

Can you check why test_working_of_tp and test_working_of_tp_and_fsdp are failing after your PR? Here's one of the tracebacks. I will do a patch to fix this later on.

stderr: [rank0]: Traceback (most recent call last):
stderr: [rank0]:   File "/__w/accelerate/accelerate/accelerate/src/accelerate/test_utils/scripts/external_deps/test_performance.py", line 299, in <module>
stderr: [rank0]:     main()
stderr: [rank0]:   File "/__w/accelerate/accelerate/accelerate/src/accelerate/test_utils/scripts/external_deps/test_performance.py", line 295, in main
stderr: [rank0]:     training_function(config, args)
stderr: [rank0]:   File "/__w/accelerate/accelerate/accelerate/src/accelerate/test_utils/scripts/external_deps/test_performance.py", line 145, in training_function
stderr: [rank0]:     model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
stderr: [rank0]:   File "/__w/accelerate/accelerate/accelerate/src/accelerate/accelerator.py", line 1541, in prepare
stderr: [rank0]:     args = self._prepare_tp(*args)
stderr: [rank0]:   File "/__w/accelerate/accelerate/accelerate/src/accelerate/accelerator.py", line 1634, in _prepare_tp
stderr: [rank0]:     param_group["params"] = [mapping[_get_tensor_address(p)] for p in param_group["params"]]
stderr: [rank0]:   File "/__w/accelerate/accelerate/accelerate/src/accelerate/accelerator.py", line 1634, in <listcomp>
stderr: [rank0]:     param_group["params"] = [mapping[_get_tensor_address(p)] for p in param_group["params"]]
stderr: [rank0]: KeyError: 139761680384000

@naomili0924
Contributor Author

Can you check why test_working_of_tp and test_working_of_tp_and_fsdp are failing after your PR? Here's one of the tracebacks. I will do a patch to fix this later on.


I resolved this failing test case in #3845. Please take a look, thanks!

Comment on lines +1593 to +1594
old_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=True))

Member


Didn't see this during review, but we shouldn't put FSDP-related code here if possible. If we have to, it should be behind a condition.

Comment on lines +1618 to +1634
new_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=False))
# Build a map from old to new params
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}

def _get_tensor_address(p):
    if isinstance(p, DTensor):
        return p._local_tensor.data_ptr()
    return p.data_ptr()

for obj in result:
    if isinstance(obj, torch.optim.Optimizer):
        for param_group in obj.param_groups:
            # Each param_group originally maps to model parameters (e.g., from model.parameters()).
            # After _prepare_tp(), parameter references are replaced with DTensor instances.
            # Therefore, we remap the parameter references to their new DTensor addresses
            # so that the optimizer can correctly update the model parameters.
            param_group["params"] = [mapping[_get_tensor_address(p)] for p in param_group["params"]]
Member


Also, we are already modifying the optimizer in _prepare_fsdp2 when FSDP2 is activated, so we shouldn't modify it here if it is enabled.
