Comparing how models are wrapped in FSDP-1 and FSDP-2, the issue seems to be that FSDP-2 wraps all sub-modules of the transformer layer (for Qwen in my case, but the same happens for other models), and it's certainly done on purpose (some interplay with sharding; @S1ro1 would perhaps know for sure?):
```python
parent_module = model.get_submodule(parent_name) if parent_name else model
if auto_wrap_policy_func(parent_module):
    layer = checkpoint_wrapper(layer, preserve_rng_state=False)
    parent_module.register_module(child_name, layer)
```
As a result, the MLP and attention are wrapped separately (as well as the layernorms), whereas in FSDP-1 only the whole transformer layer was wrapped, which is much more memory efficient in this case.
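For reference, one way to get back the FSDP-1 granularity is to apply the checkpoint wrapper manually, only to whole decoder layers. A minimal sketch, assuming a Qwen2 model; the model ID and the `isinstance` check are mine, not the accelerate code path:

```python
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from transformers import AutoModelForCausalLM
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")  # example model

apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    # Wrap only complete transformer blocks, not their MLP/attention/layernorm children.
    check_fn=lambda module: isinstance(module, Qwen2DecoderLayer),
)
```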
Falling back to `activation_checkpointing=False` and `gradient_checkpointing=True` fixes the issue for FSDP-2, but is that not advisable (@SunMarc)? See huggingface/transformers#30404
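For completeness, a minimal sketch of that fallback (assuming a recent accelerate where the plugin exposes `fsdp_version` and `activation_checkpointing`; treat the exact option names as an assumption, not a recommendation):

```python
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")  # example model

# Skip FSDP2's per-sub-module checkpoint wrapping ...
fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    activation_checkpointing=False,
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

# ... and rely on HF-side gradient checkpointing, which checkpoints whole decoder layers.
model.gradient_checkpointing_enable()
model = accelerator.prepare(model)
```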