Add get base model state dict #3000
Conversation
@BenjaminBossan Would be glad if you could review it when you get a chance.
BenjaminBossan left a comment
Thanks for this PR. I currently don't have much time to review and will be OoO next week, so hopefully @githubnemo can take over.
Just my first observations:
- From the original issue, I think we concluded that we would rather need a `set_base_model_state_dict`. That doesn't mean that `get_base_model_state_dict` doesn't have its merits, but it wouldn't fully solve the issue. Ping @dvmazur.
- There can be deeper nesting of `.base_layer`, so it should run in a loop: `while ".base_layer" in new_key: ...` (see the sketch after this list).
- This doesn't take into account trainable tokens yet; they need to be treated similarly to `modules_to_save`.
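A minimal sketch of that stripping loop, assuming the wrapped keys differ from the originals only by injected `.base_layer` segments (the helper name is hypothetical, not part of the PR):

```python
# Hypothetical helper: strip nested ".base_layer" segments from a
# PEFT-wrapped key to recover the base model's original key.
def strip_base_layer(peft_key: str) -> str:
    new_key = peft_key
    while ".base_layer" in new_key:  # handle arbitrarily deep nesting
        new_key = new_key.replace(".base_layer", "", 1)
    return new_key

# e.g. a LoRA-wrapped linear:
#   "...q_proj.base_layer.weight" -> "...q_proj.weight"
```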
Hi! Thanks for this PR! Yeah, I rather need a `set_base_model_state_dict`.
Hi, I'll add the set method as well and more tests a little later today.
Added the set base state dict and more tests.
@githubnemo Would be glad if you could review this :)
githubnemo left a comment
Hey @Isalia20 :) Thanks for taking this on.
I need a bit of clarification (possibly from @dvmazur): do I understand correctly that one use case is that we have a model that doesn't fit in memory, so we first need to shard the empty model (via FSDP) onto several devices and then read the checkpoint onto the shards (in a streaming manner)? Furthermore, do I understand correctly that it is not possible to shard the base model and then apply PEFT on top of that? If those are the reasons why this is useful, we should probably document them as well, since they are not obvious.
The implementation seems to pass at first glance but there might be a few pitfalls still. I left one comment regarding a potential bug.
Let's build a test (e.g., a merge of `test_get_base_model_state_dict_keys_match` and `test_get_base_model_state_dict_values_match`) and integrate it into `tests/testing_common.py` (similar to `_test_save_pretrained`), to be called from the more exhaustive testing suites in `tests/test_decoder_models.py`, `tests/test_encoder_decoder_models.py`, and `tests/test_custom_models.py`, which cover a lot more cases. For example, trainable tokens and parameter targeting are not covered by the current tests, and there are probably many more special cases, so leveraging the existing tests is probably best.
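A rough sketch of what such a combined round-trip test might look like, assuming the PR's `get_base_model_state_dict` method and the usual `testing_common.py` conventions (the test name is illustrative, not the final implementation):

```python
import torch
from peft import get_peft_model

# Hypothetical combined test for tests/testing_common.py, merging the
# keys-match and values-match checks into one round trip.
def _test_get_base_model_state_dict(self, model_id, config_cls, config_kwargs):
    base_model = self.transformers_class.from_pretrained(model_id)
    expected = {k: v.clone() for k, v in base_model.state_dict().items()}

    config = config_cls(base_model_name_or_path=model_id, **config_kwargs)
    model = get_peft_model(base_model, config)

    extracted = model.get_base_model_state_dict()  # method added by this PR
    assert extracted.keys() == expected.keys()
    for key, value in expected.items():
        assert torch.allclose(extracted[key], value)
```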
```python
for prefix in adapter_prefixes:
    if f".{prefix}" in peft_key or peft_key.startswith(prefix):
        is_adapter_param = True
        break

if is_adapter_param:
    continue
```
I think this is not a sufficient filter for methods like VeRA or VB-LoRA that employ weight sharing. This will be covered by the extended tests, I suppose.
An alternative approach would be to iterate over all named modules of the model and remove those keys that belong to `BaseTunerLayer` instances (since the weight-shared keys are caught by the prefix matching already in place). But let's see what the tests say first, maybe I'm wrong and everything works fine :)
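A minimal sketch of that alternative, assuming `BaseTunerLayer` from `peft.tuners.tuners_utils`; the helper name is hypothetical, and keys under `.base_layer` are kept since they hold the original base weights:

```python
from peft.tuners.tuners_utils import BaseTunerLayer

def drop_tuner_layer_keys(model):
    # Hypothetical helper: collect state_dict keys of parameters living
    # directly on a BaseTunerLayer (adapter params), but keep everything
    # under its .base_layer, which wraps the original base weights.
    adapter_keys = set()
    for name, module in model.named_modules():
        if isinstance(module, BaseTunerLayer):
            for param_name, _ in module.named_parameters():
                if not param_name.startswith("base_layer."):
                    adapter_keys.add(f"{name}.{param_name}")
    return {k: v for k, v in model.state_dict().items() if k not in adapter_keys}
```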
Hi! I want to be able to load the base model's and adapter's state_dicts after wrapping the PEFT model in FSDP. The state_dict's keys match the original base model's keys, so I need a function that maps the wrapped model's keys to the original ones, if that makes sense.
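For illustration, this is the kind of mismatch being described (key names assumed for a typical LoRA-wrapped causal LM, not taken from the PR):

```python
# Key of a linear layer in the original base model:
#   "model.layers.0.self_attn.q_proj.weight"
# The same parameter after get_peft_model() wraps it in a LoRA layer:
#   "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight"
```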
Thanks for the comments. I'll take a look a little later this week.
Hey @dvmazur,
I got that, but why? What's your motivation? My question presupposed that memory is a constraint and that that's the reason, but you neither acknowledged nor refuted that. Please give a bit more detail so that I can understand the use case better. Thanks!
The end goal is basically to have PEFT working for TorchTitan. Titan wraps models in FSDP to save VRAM, and it allocates GPU memory only after the model's meta-device weights have been FSDP-sharded. I think this pseudocode snippet should give you enough info, but feel free to ask if you need anything more:

```python
with torch.device("meta"):
    # can't load base model weights here, as the model is on the meta device before resharding
    model = AutoModelForCausalLM.from_pretrained(...)

# can only wrap the model in PEFT before FSDP-sharding it
model = get_peft_model(model, ...)
model = fsdp_shard_model(model)

# actually allocate memory for the model's weights;
# the state dict can be loaded after that
model.to_empty(device=init_device)

# this function loads a state dict with the original model's module keys,
# so I need a way to map them to the PEFT-wrapped model
load_base_model_state_dict(model)
initialize_adapters(model)
```
Fixes #2945