
Conversation

@Isalia20

Fixes #2945 (Return base model state_dict with original keys)

@Isalia20 (Author)

@BenjaminBossan Would be glad if you could review it when you get a chance

@BenjaminBossan (Member) left a comment

Thanks for this PR. I currently don't have much time to review and will be OoO next week, so hopefully @githubnemo can take over.

Just my first observations:

  1. From the original issue, I think we concluded that we would rather need a set_base_model_state_dict. That doesn't mean that get_base_model_state_dict doesn't have its merits, but it wouldn't fully solve the issue. Ping @dvmazur.
  2. There can be deeper nesting of .base_layer, so it should run in a loop: while ".base_layer" in new_key: ... (see the sketch after this list).
  3. This doesn't take trainable tokens into account yet; they need to be treated similarly to modules_to_save.
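
Roughly what I mean for points 2 and 3, as a hypothetical sketch (not the actual implementation; modules_to_save and trainable token handling are only hinted at):

def remap_key(peft_key: str) -> str:
    # Hypothetical helper: strip the PEFT wrapper prefix and collapse
    # arbitrarily deep ".base_layer" nesting back to the original base-model key.
    new_key = peft_key.removeprefix("base_model.model.")
    while ".base_layer" in new_key:
        new_key = new_key.replace(".base_layer", "", 1)
    # modules_to_save (and trainable tokens) wrap the original module under an
    # adapter-specific sub-module, so they need similar, adapter-aware handling;
    # omitted here for brevity.
    return new_key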

@dvmazur commented Jan 17, 2026

Hi! Thanks for this PR! Yes, what I actually need is a set_base_model_state_dict, but I think it should be pretty easy to implement once we have get_base_model_state_dict. Also, maybe we should expand the test matrix to make sure this method works for other PEFT methods?
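
To make the "should be easy" concrete, here is a rough, self-contained sketch of how set_base_model_state_dict could reuse the same key remapping (hypothetical code, not a concrete proposal for the PR; modules_to_save and weight-sharing methods would need extra care):

def set_base_model_state_dict(peft_model, base_state_dict):
    # Map each wrapped key back to its original name, the same way
    # get_base_model_state_dict would, then load the checkpoint under the
    # wrapped names without touching adapter weights.
    def to_original_key(wrapped_key: str) -> str:
        key = wrapped_key.removeprefix("base_model.model.")
        while ".base_layer" in key:
            key = key.replace(".base_layer", "", 1)
        return key

    key_map = {to_original_key(k): k for k in peft_model.state_dict()}  # original -> wrapped
    remapped = {key_map[k]: v for k, v in base_state_dict.items() if k in key_map}
    return peft_model.load_state_dict(remapped, strict=False)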

@Isalia20 (Author)

Hi, I'll add the set method as well as more tests a little later today.

@Isalia20 (Author)

Added the set_base_model_state_dict method and more tests.

@Isalia20 (Author)

@githubnemo Would be glad if you could review this :)

@githubnemo (Collaborator) left a comment

Hey @Isalia20 :) Thanks for taking this on.

I need a bit of clarification (possibly from @dvmazur): do I understand correctly that one use case is that we have a model that doesn't fit in memory, so we need to first shard the empty model (via FSDP) onto several devices and then read the checkpoint onto the shards (in a streaming manner)? Furthermore, do I understand correctly that it is not possible to shard the base model first and then apply PEFT on top of that? If those are the reasons why this is useful, we should probably document them as well, since they are not obvious.

The implementation seems to pass at first glance, but there might still be a few pitfalls. I left one comment regarding a potential bug.

Let's build a test (e.g., a merge of test_get_base_model_state_dict_keys_match and test_get_base_model_state_dict_values_match) and integrate it into tests/testing_common.py (similar to _test_save_pretrained) so it can be called from the more exhaustive testing suites in tests/test_decoder_models.py, tests/test_encoder_decoder_models.py, and tests/test_custom_models.py, which cover a lot more cases. For example, trainable tokens and parameter targeting are not covered by the current tests, and there are probably a lot more special cases, so leveraging the existing tests is probably best.
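
For illustration only, a standalone version of the keys check could look roughly like this (assuming get_base_model_state_dict ends up as a method on the PEFT model; the real test should go through tests/testing_common.py as described above):

import torch.nn as nn
from peft import LoraConfig, get_peft_model


class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(8, 8)
        self.lin1 = nn.Linear(8, 2)

    def forward(self, x):
        return self.lin1(self.lin0(x))


def test_get_base_model_state_dict_keys_match():
    base_model = TinyMLP()
    expected_keys = set(base_model.state_dict().keys())

    peft_model = get_peft_model(base_model, LoraConfig(target_modules=["lin0"]))
    # get_base_model_state_dict is the method proposed in this PR; its exact
    # signature may differ.
    remapped_keys = set(peft_model.get_base_model_state_dict().keys())

    assert remapped_keys == expected_keys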

Comment on lines +1771 to +1777
for prefix in adapter_prefixes:
    if f".{prefix}" in peft_key or peft_key.startswith(prefix):
        is_adapter_param = True
        break

if is_adapter_param:
    continue
@githubnemo (Collaborator)

I think this is not a sufficient filter for methods like VeRA or VB-LoRA that employ weight sharing. This will be covered by the extended tests I suppose.

An alternative approach would be to iterate over all named modules of the model and remove those keys that belong to BaseTunerLayer instances (the weight-shared keys are already caught by the prefix matching that is in place). But let's see what the tests say first, maybe I'm wrong and everything works fine :)
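
A rough sketch of that alternative, just to illustrate the idea (hypothetical helper, simplified):

from peft.tuners.tuners_utils import BaseTunerLayer


def adapter_only_keys(peft_model):
    # Collect the state_dict keys that live under BaseTunerLayer modules but do
    # not belong to the wrapped base_layer, so they can be dropped when
    # building the base-model state dict.
    keys = set()
    for module_name, module in peft_model.named_modules():
        if isinstance(module, BaseTunerLayer):
            for param_name, _ in module.named_parameters():
                if not param_name.startswith("base_layer."):
                    keys.add(f"{module_name}.{param_name}")
    return keys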

@dvmazur commented Jan 21, 2026

Hi!

I want to be able to load the base model's and adapter's state_dicts after wrapping the PEFT model in FSDP. The state_dict I want to load has keys matching the original base model's keys, so I need a function that maps the wrapped model's keys to the original ones, if that makes sense.
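
For a concrete illustration of the mapping (a LoRA-wrapped Llama-style projection, hypothetical example):

# Key as it appears in the PEFT-wrapped (and later FSDP-wrapped) model:
peft_key = "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight"
# Key as it appears in the original base model's checkpoint:
original_key = "model.layers.0.self_attn.q_proj.weight"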

@Isalia20 (Author)

Thanks for the comments. I'll take a look a little later this week.

@githubnemo (Collaborator)

Hey @dvmazur,

I want to be able to load the base model's and adapter's state_dicts after wrapping the PEFT model in FSDP.

I got that, but why? What's your motivation? My question assumed that memory is a constraint and that this is the reason, but you neither confirmed nor refuted that. Please give a bit more detail so that I can understand the use case better. Thanks!

@dvmazur commented Jan 21, 2026

The end goal is basically to have PEFT working for TorchTitan. Titan wraps models in FSDP to save VRAM, and it allocates GPU memory only after the model's meta-device weights have been FSDP-sharded.

I think this pseudocode snippet should give you enough context, but feel free to ask if you need more details:

with torch.device("meta"):
    # can't load base model weights here as it is on meta device before resharding
    model = AutoModelForCausalLM.from_pretrained(...)
    # can only wrap model in peft before fsdp-sharding it
    model = get_perft_model(model, ...)

model = fsdp_shard_model(model)

# actually allocate memory for the model's weights
# state dict can be loaded after that
model.to_empty(device=init_device)

# this function loads a state dict with the original model's module keys
# so I need a way to map them to the PEFT-wrapped model
load_base_model_state_dict(model)
initialize_adapters(model)
