Commit ab552e1

Add convert_module to FSDP
1 parent 5dea36c commit ab552e1

File tree

4 files changed: +49 -0 lines changed


src/lightning/fabric/plugins/precision/fsdp.py

Lines changed: 6 additions & 0 deletions
@@ -73,6 +73,12 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
         }
         self._desired_input_dtype = precision_to_type[self.precision]
 
+    @override
+    def convert_module(self, module: Module) -> Module:
+        if "true" in self.precision:
+            return module.to(dtype=self._desired_input_dtype)
+        return module
+
     @property
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
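For context, a minimal sketch of what the new hook does on the Fabric side, assuming the module path from the file above and a toy torch.nn.Linear layer: the "-true" precision settings cast the parameters themselves, while the "-mixed" settings return the module unchanged and defer the casting to FSDP's MixedPrecision config at runtime.

import torch
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision  # module path taken from this diff

# Toy module; parameters start out in float32.
module = torch.nn.Linear(4, 4)

# "bf16-true": convert_module casts the parameters to the desired dtype.
true_precision = FSDPPrecision(precision="bf16-true")
module = true_precision.convert_module(module)
assert module.weight.dtype == torch.bfloat16

# "bf16-mixed": the module is returned unchanged; bf16 casting is instead
# handled by the MixedPrecision config FSDP applies during forward/backward.
mixed_precision = FSDPPrecision(precision="bf16-mixed")
other = mixed_precision.convert_module(torch.nn.Linear(4, 4))
assert other.weight.dtype == torch.float32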

src/lightning/pytorch/plugins/precision/fsdp.py

Lines changed: 7 additions & 0 deletions
@@ -16,6 +16,7 @@
 import torch
 from lightning_utilities import apply_to_collection
 from torch import Tensor
+from torch.nn import Module
 from typing_extensions import get_args, override
 
 import lightning.pytorch as pl
@@ -72,6 +73,12 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
         }
         self._desired_input_dtype = precision_to_type[self.precision]
 
+    @override
+    def convert_module(self, module: Module) -> Module:
+        if "true" in self.precision:
+            return module.to(dtype=self._desired_input_dtype)
+        return module
+
     @override
     def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
         # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
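On the Trainer side, a hedged usage sketch of where this lands: passing a "-true" precision string together with the FSDP strategy is assumed to select this plugin, which now casts the LightningModule's parameters up front instead of leaving them in float32. The Trainer argument spelling and the multi-GPU setup below are assumptions for illustration, not part of this diff.

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision  # module path taken from this diff

# Assumed wiring (requires a multi-GPU environment): the FSDP strategy with
# precision="bf16-true" uses this plugin, so parameters are cast to bfloat16
# before sharding.
trainer = Trainer(accelerator="gpu", devices=2, strategy="fsdp", precision="bf16-true")

# The cast itself can be checked directly on the plugin:
plugin = FSDPPrecision(precision="bf16-true")
layer = plugin.convert_module(torch.nn.Linear(4, 4))
assert layer.weight.dtype == torch.bfloat16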

tests/tests_fabric/plugins/precision/test_fsdp.py

Lines changed: 18 additions & 0 deletions
@@ -127,3 +127,21 @@ def test_invalid_precision_with_fsdp_precision():
 
     with pytest.raises(ValueError, match="is not supported in FSDP. `precision` must be one of"):
         FSDPPrecision(precision="64-true")
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("32-true", torch.float32),
+        ("bf16-mixed", torch.float32),
+        ("16-mixed", torch.float32),
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.float16),
+    ],
+)
+def test_convert_module(precision, expected_dtype):
+    precision = FSDPPrecision(precision=precision)
+    module = torch.nn.Linear(2, 2)
+    assert module.weight.dtype == module.bias.dtype == torch.float32
+    module = precision.convert_module(module)
+    assert module.weight.dtype == module.bias.dtype == expected_dtype

tests/tests_pytorch/plugins/precision/test_fsdp.py

Lines changed: 18 additions & 0 deletions
@@ -40,6 +40,24 @@ def test_fsdp_precision_config(precision, expected):
     assert config.reduce_dtype == expected[2]
 
 
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("32-true", torch.float32),
+        ("bf16-mixed", torch.float32),
+        ("16-mixed", torch.float32),
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.float16),
+    ],
+)
+def test_convert_module(precision, expected_dtype):
+    precision = FSDPPrecision(precision=precision)
+    module = torch.nn.Linear(2, 2)
+    assert module.weight.dtype == module.bias.dtype == torch.float32
+    module = precision.convert_module(module)
+    assert module.weight.dtype == module.bias.dtype == expected_dtype
+
+
 def test_fsdp_precision_default_scaler():
     from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
