Skip to content

Commit 115ee41

Browse files
committed
check
1 parent cba0c48 commit 115ee41

File tree

4 files changed

+36
-32
lines changed

4 files changed

+36
-32
lines changed

src/lightning/fabric/plugins/precision/fsdp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def convert_module(self, module: Module) -> Module:
8585
def mixed_precision_config(self) -> "TorchMixedPrecision":
8686
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
8787

88-
if "true" in self.precision and self.precision != "32-true":
88+
if self.precision in ("16-true", "bf16-true"):
8989
rank_zero_warn(
9090
f"FSDPPrecision `{self.precision}` enables mixed-precision execution. "
9191
"Model parameters remain in full precision `torch.float32`, while forward and backward passes "

src/lightning/pytorch/plugins/precision/fsdp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
9595
def mixed_precision_config(self) -> "TorchMixedPrecision":
9696
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
9797

98-
if "true" in self.precision and self.precision != "32-true":
98+
if self.precision in ("16-true", "bf16-true"):
9999
rank_zero_warn(
100100
f"FSDPPrecision `{self.precision}` enables mixed-precision execution. "
101101
"Model parameters remain in full precision `torch.float32`, while forward and backward passes "

tests/tests_fabric/plugins/precision/test_fsdp.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import warnings
14+
from contextlib import contextmanager
1515
from unittest.mock import Mock
1616

1717
import pytest
@@ -22,26 +22,28 @@
2222
from tests_fabric.helpers.runif import RunIf
2323

2424

25+
@contextmanager
26+
def null_ctx_manager(*args, **kwargs):
27+
yield
28+
29+
2530
@pytest.mark.parametrize(
26-
("precision", "expected", "expect_warn"),
31+
("precision", "expected"),
2732
[
28-
("16-true", (torch.float16, torch.float16, torch.float16), True),
29-
("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16), True),
30-
("16-mixed", (torch.float16, torch.float16, torch.float16), True),
31-
("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16), True),
32-
("32-true", (torch.float32, torch.float32, torch.float32), False),
33+
("16-true", (torch.float16, torch.float16, torch.float16)),
34+
("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
35+
("16-mixed", (torch.float16, torch.float16, torch.float16)),
36+
("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
37+
("32-true", (torch.float32, torch.float32, torch.float32)),
3338
],
3439
)
35-
def test_fsdp_precision_config(precision, expected, expect_warn):
36-
with warnings.catch_warnings(record=True) as w:
37-
warnings.simplefilter("always") # capture all warnings
38-
plugin = FSDPPrecision(precision=precision)
40+
def test_fsdp_precision_config(precision, expected):
41+
plugin = FSDPPrecision(precision=precision)
3942

40-
# Check if the warning was (or wasn’t) logged
41-
has_warn = any("FSDPPrecision" in str(warning.message) for warning in w)
42-
assert has_warn == expect_warn, f"Unexpected warning state for {precision}"
43+
warning_ctx = pytest.warns if precision in ("16-true", "bf16-true") else null_ctx_manager
4344

44-
config = plugin.mixed_precision_config
45+
with warning_ctx(UserWarning, match="enables mixed-precision execution"):
46+
config = plugin.mixed_precision_config
4547

4648
assert config.param_dtype == expected[0]
4749
assert config.buffer_dtype == expected[1]

tests/tests_pytorch/plugins/precision/test_fsdp.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import warnings
14+
from contextlib import contextmanager
1515
from unittest.mock import ANY, MagicMock, Mock
1616

1717
import pytest
@@ -22,26 +22,28 @@
2222
from tests_pytorch.helpers.runif import RunIf
2323

2424

25+
@contextmanager
26+
def null_ctx_manager(*args, **kwargs):
27+
yield
28+
29+
2530
@pytest.mark.parametrize(
26-
("precision", "expected", "expect_warn"),
31+
("precision", "expected"),
2732
[
28-
("16-true", (torch.float16, torch.float16, torch.float16), True),
29-
("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16), True),
30-
("16-mixed", (torch.float16, torch.float16, torch.float16), True),
31-
("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16), True),
32-
("32-true", (torch.float32, torch.float32, torch.float32), False),
33+
("16-true", (torch.float16, torch.float16, torch.float16)),
34+
("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
35+
("16-mixed", (torch.float16, torch.float16, torch.float16)),
36+
("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
37+
("32-true", (torch.float32, torch.float32, torch.float32)),
3338
],
3439
)
35-
def test_fsdp_precision_config(precision, expected, expect_warn):
36-
with warnings.catch_warnings(record=True) as w:
37-
warnings.simplefilter("always") # capture all warnings
38-
plugin = FSDPPrecision(precision=precision)
40+
def test_fsdp_precision_config(precision, expected):
41+
plugin = FSDPPrecision(precision=precision)
3942

40-
# Check if the warning was (or wasn’t) logged
41-
has_warn = any("FSDPPrecision" in str(warning.message) for warning in w)
42-
assert has_warn == expect_warn, f"Unexpected warning state for {precision}"
43+
warning_ctx = pytest.warns if precision in ("16-true", "bf16-true") else null_ctx_manager
4344

44-
config = plugin.mixed_precision_config
45+
with warning_ctx(UserWarning, match="enables mixed-precision execution"):
46+
config = plugin.mixed_precision_config
4547

4648
assert config.param_dtype == expected[0]
4749
assert config.buffer_dtype == expected[1]

0 commit comments

Comments
 (0)