activation_checkpointing=True uses *a lot* more memory than gradient_checkpointing=True *with FSDP 2.0* (FSDP 1.0 is fine) #3897

Description

@d5031

Comparing models wrapped with FSDP-1 and FSDP-2, the issue seems to be that FSDP 2.0 wraps all sub-modules of each transformer layer (Qwen in my case, but the same happens for other models), and it's certainly done on purpose (some interplay with sharding; @S1ro1 perhaps would know for sure?). The relevant wrapping code is:

    parent_module = model.get_submodule(parent_name) if parent_name else model
    if auto_wrap_policy_func(parent_module):
        layer = checkpoint_wrapper(layer, preserve_rng_state=False)
        parent_module.register_module(child_name, layer)

So MLP and attention get wrapped separately (as well as the layernorms), while in FSDP-1 only the whole transformer layer was wrapped, which is much more memory efficient in this case.
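To illustrate the FSDP-1-style behaviour concretely, here is a minimal sketch (not Accelerate's actual code path) that checkpoints only whole decoder layers via `apply_activation_checkpointing`. `Qwen2DecoderLayer` and the already-loaded `model` are assumptions for the example; substitute your model's block class.

    # Hedged sketch: wrap whole transformer blocks (FSDP-1-style) instead of
    # their individual sub-modules. Qwen2DecoderLayer is an illustrative choice.
    from functools import partial

    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        apply_activation_checkpointing,
        checkpoint_wrapper,
    )
    from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer

    wrapper = partial(checkpoint_wrapper, preserve_rng_state=False)

    apply_activation_checkpointing(
        model,  # assumed to be an already-loaded Qwen model
        checkpoint_wrapper_fn=wrapper,
        # Only complete decoder layers are checkpointed, so MLP, attention and
        # layernorm activations inside each block are recomputed together.
        check_fn=lambda m: isinstance(m, Qwen2DecoderLayer),
    )

With a `check_fn` like this, only one activation checkpoint boundary exists per transformer block, which is what FSDP-1 effectively did.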

Falling back to activation_checkpointing=False and gradient_checkpointing=True fixes the issue for FSDP-2, but is that not advisable (@SunMarc)? See huggingface/transformers#30404.
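For reference, a rough sketch of that fallback configuration, assuming a recent Accelerate with FSDP2 support (plugin field names may differ by version, and this fragment is meant to run under `accelerate launch`):

    # Hedged sketch of the workaround: disable FSDP's activation checkpointing
    # and use the model's own (non-reentrant) gradient checkpointing instead.
    from accelerate import Accelerator, FullyShardedDataParallelPlugin

    fsdp_plugin = FullyShardedDataParallelPlugin(
        fsdp_version=2,                  # FSDP 2.0 path
        activation_checkpointing=False,  # avoid per-sub-module checkpoint wrapping
    )
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

    # `model` is assumed to be a Transformers PreTrainedModel loaded beforehand.
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": False}
    )
    model = accelerator.prepare(model)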
