Labels: fabric (lightning.fabric.Fabric), feature (Is an improvement or enhancement), strategy: fsdp (Fully Sharded Data Parallel)
Description & Motivation
Currently, very large models can't be instantiated before sharding if they don't fit in CPU memory. The solution, similar to how other libraries such as DeepSpeed handle it, is to create fake tensors that don't allocate memory, shard them, and only materialize them on each device once sharded.
Pitch
Adopt the FakeTensorMode context manager. It is not officially documented, but we can start experimenting with it in Fabric.
The usage in PyTorch would be:

from torch._subclasses import FakeTensorMode

with FakeTensorMode():
    model = MyModel(...)  # tensors are fake and don't allocate memory
This would translate to

with fabric.sharded_model():
    model = MyModel(...)

in Fabric.
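
A minimal sketch of how such a context manager could be built on top of FakeTensorMode. This is an assumption about the eventual implementation, not Fabric's actual API (sharded_model is only the proposed name); materializing the fake parameters after sharding is the hard part and is intentionally left out:

from contextlib import contextmanager

from torch._subclasses import FakeTensorMode


@contextmanager
def sharded_model():
    # Every tensor created inside this block is a FakeTensor that carries
    # only metadata (shape, dtype, device), so arbitrarily large models can
    # be instantiated without fitting in CPU memory.
    with FakeTensorMode():
        yield

After the strategy decides the placement of each shard, it would replace the fake parameters with real, device-local tensors.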
Alternatives
A similar mechanism exists in torchdistx.
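
For comparison, a hedged sketch of the torchdistx route, assuming its documented deferred_init/materialize_module API (torchdistx is a separate package and must be installed):

import torch.nn as nn
from torchdistx.deferred_init import deferred_init, materialize_module

# Construct the module without allocating real storage.
model = deferred_init(nn.Linear, 50_000, 50_000)

# ... shard / decide parameter placement here ...

# Allocate real storage only once placement is known.
materialize_module(model)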
Additional context
Proposed by @justusschock