
Adopt FakeTensorMode for FSDP #16448

@awaelchli

Description & Motivation

Currently, very large models can't be instantiated before sharding if they don't fit in CPU memory. The solution, similar to how other libraries like DeepSpeed handle it, is to create fake tensors that don't allocate memory, shard them, and only materialize them once they are sharded on each device.
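
A minimal sketch of the underlying mechanism (using the private torch._subclasses import path, which may change since FakeTensorMode is not a public API):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Inside the mode, tensor constructors only record shape/dtype/device
# metadata; no storage is allocated, so even huge tensors are "free".
with FakeTensorMode():
    weight = torch.empty(100_000, 100_000)  # would be ~40 GB if real

print(weight.shape)  # torch.Size([100000, 100000]), yet no memory was allocated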

Pitch

Adopt the FakeTensorMode context manager. It is not officially documented, but we can start experimenting with it in Fabric.

The usage in PyTorch would be:

from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    model = MyModel(...)  # tensors are fake and don't allocate memory

This would translate to

with fabric.sharded_model():
    model = MyModel(...)

in Fabric.
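
Put together, a hypothetical end-to-end flow could look as follows. Note that sharded_model() is only the API proposed here, and the assumption that fabric.setup() shards the fake parameters and materializes only the local shard on each device is part of the proposal, not existing behavior (the exact FSDP strategy flag is also glossed over):

from lightning.fabric import Fabric

fabric = Fabric(strategy="fsdp", devices=4)
fabric.launch()

# Proposed: parameters created in this context are fake and take no memory
with fabric.sharded_model():
    model = MyModel(...)

# Assumed behavior: setup() shards the fake parameters and then
# materializes only the local shard on each device
model = fabric.setup(model)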

Alternatives

A similar mechanism exists in torchdistx.
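
For comparison, torchdistx exposes deferred_init and materialize_module; roughly (a sketch from memory, consult the torchdistx docs for exact signatures):

from torchdistx.deferred_init import deferred_init, materialize_module

# Record the construction of the module without allocating real storage
model = deferred_init(MyModel, ...)

# Later, e.g. after sharding decisions are made, allocate and initialize
materialize_module(model)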

Additional context

Proposed by @justusschock


cc @Borda @carmocca @justusschock @awaelchli


Labels

fabric (lightning.fabric.Fabric), feature (Is an improvement or enhancement), strategy: fsdp (Fully Sharded Data Parallel)
