DP Replicate Groups and collective reduction with FSDP2 APIs #21059

@nathan-az

Description & Motivation

Collective Scheduling control

The FSDP2 APIs offer significant flexibility in controlling when collectives are performed. For example, resharding (after forward or backward) can be disabled between TP groups until the final gradient accumulation step, or even between DP groups to emulate DDP if the model is sufficiently small.
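To make the scheduling idea concrete, here is a minimal sketch of how it could look during gradient accumulation. It assumes `model` has been wrapped with FSDP2's `fully_shard`, so that it exposes `set_reshard_after_backward` and `set_requires_gradient_sync`; the helper names `configure_microbatch` and `accumulation_step` are hypothetical, not Lightning or PyTorch APIs.

```python
def configure_microbatch(model, microbatch_idx: int, num_microbatches: int) -> bool:
    """Set FSDP2 communication flags for one microbatch.

    `model` is assumed to be an FSDP2-wrapped module (an assumption of
    this sketch). Returns True when this is the last microbatch.
    """
    is_last = microbatch_idx == num_microbatches - 1
    # Keep parameters unsharded after backward until the last microbatch,
    # trading higher memory for fewer all-gathers.
    model.set_reshard_after_backward(is_last)
    # Skip gradient reduction until the last microbatch; gradients
    # accumulate unsharded in the meantime (DDP-like `no_sync` behaviour).
    model.set_requires_gradient_sync(is_last)
    return is_last


def accumulation_step(model, optimizer, microbatches, loss_fn):
    """One optimizer step over several microbatches (hypothetical helper)."""
    n = len(microbatches)
    for idx, batch in enumerate(microbatches):
        configure_microbatch(model, idx, n)
        # Scale the loss so the accumulated gradient is an average.
        (loss_fn(model(batch)) / n).backward()
    optimizer.step()
    optimizer.zero_grad()
```

This only pays off when the unsharded parameters and gradients fit in memory, which is why it suits smaller models or small sharding groups.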

Support for an additional dimension in Data Parallelism/FSDP

As far as I can tell, Lightning doesn't currently support 2D data sharding groups (i.e. with multiple nodes, sharding the model within each node instead of across all workers). Doing so reduces the number of required collectives, e.g. by deferring the gradient all-reduce (via set_requires_all_reduce and set_is_last_backward) until the final gradient accumulation step. I believe this is equivalent (or close) to FSDP1's HYBRID_SHARD strategy.
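As an illustration of the 2D layout, the sketch below computes which ranks would share a shard group and which would share a replicate group, assuming a row-major `(replicate, shard)` mesh like the one `init_device_mesh` builds. The function names `hsdp_groups` and `defer_all_reduce` are hypothetical; passing such a 2D mesh to `fully_shard(module, mesh=mesh)` is how FSDP2 selects hybrid sharding.

```python
def hsdp_groups(world_size: int, shard_size: int):
    """Return (shard_groups, replicate_groups) for a 2D data-parallel mesh.

    Assumes ranks are laid out row-major on a (replicate, shard) mesh, so
    each shard group is a contiguous block of ranks (e.g. one node).
    """
    if world_size % shard_size != 0:
        raise ValueError("world_size must be divisible by shard_size")
    replicate_size = world_size // shard_size
    mesh = [
        [r * shard_size + s for s in range(shard_size)]
        for r in range(replicate_size)
    ]
    shard_groups = mesh  # rows: parameters are sharded within each row
    replicate_groups = [  # columns: gradients are all-reduced across rows
        [mesh[r][s] for r in range(replicate_size)] for s in range(shard_size)
    ]
    return shard_groups, replicate_groups


def defer_all_reduce(model, is_last_microbatch: bool) -> None:
    """Run only the intra-group reduce-scatter until the last microbatch.

    `model` is assumed to be an FSDP2 module sharded over a 2D mesh.
    """
    model.set_requires_all_reduce(is_last_microbatch)


# Two nodes with four GPUs each: shard within a node, replicate across nodes.
shard_groups, replicate_groups = hsdp_groups(world_size=8, shard_size=4)
print(shard_groups)      # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(replicate_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

With this layout the all-gathers and reduce-scatters stay inside a node, and only the deferred all-reduce crosses the inter-node link.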

The motivation here is that the collectives required for FSDP (sharding and unsharding model parameters) are generally cheaper than those required for TP (sharing activations, potentially at very high sequence lengths). So unless memory is an issue even with FSDP, or the global batch size becomes too large, using only FSDP over a 2D mesh and controlling collectives can yield higher tokens per second.

The above are all particularly helpful when training very large models, and when inter-node network bandwidth is a limiting factor (i.e. synchronising gradients unnecessarily every step is costly).

Apologies if any of the above is implemented and I've missed it.

Pitch

  • Support separation of data_parallel to data_parallel_shard_dim and data_parallel_replicate_dim
  • Support greater flexibility in collective scheduling, especially when using gradient accumulation
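One possible shape for the first bullet, purely as a sketch: a small config object that splits the data-parallel size into the two proposed dimensions and validates them against the world size. The class name `DataParallelDims`, the `-1` "infer" convention, and the `resolve` method are all assumptions made for illustration, not an existing Lightning API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DataParallelDims:
    """Hypothetical config splitting data_parallel into two dimensions."""

    data_parallel_replicate_dim: int = 1
    data_parallel_shard_dim: int = -1  # -1 means "infer from world size"

    def resolve(self, world_size: int) -> tuple[int, int]:
        """Return (replicate, shard) sizes that multiply to world_size."""
        replicate = self.data_parallel_replicate_dim
        shard = self.data_parallel_shard_dim
        if replicate < 1:
            raise ValueError("data_parallel_replicate_dim must be >= 1")
        if shard == -1:
            if world_size % replicate != 0:
                raise ValueError("world_size not divisible by replicate dim")
            shard = world_size // replicate
        if replicate * shard != world_size:
            raise ValueError("replicate * shard must equal world_size")
        return replicate, shard


# Two replicas across 16 workers: shard over 8 workers within each replica.
print(DataParallelDims(data_parallel_replicate_dim=2).resolve(16))  # (2, 8)
```

The resolved pair could then be used directly as the shape of a 2D device mesh; the default of one replica reduces to plain FSDP across all workers.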

Additional context

I believe that all of the above should more-or-less be model agnostic and can sit generically inside the training loop.

cc @lantiga @Borda (apologies I accidentally removed these mentions during an edit)


Labels: feature, strategy: fsdp
