Skip to content

Commit 8a50fd0

Browse files
committed
update
1 parent f94ab73 commit 8a50fd0

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

tests/tests_fabric/plugins/precision/test_fsdp.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
1415
from unittest.mock import Mock
1516

1617
import pytest
@@ -32,13 +33,14 @@
3233
],
3334
)
3435
def test_fsdp_precision_config(precision, expected, expect_warn):
35-
if expect_warn:
36-
with pytest.warns(UserWarning, match="FSDPPrecision.*runs computations in reduced precision"):
37-
plugin = FSDPPrecision(precision=precision)
38-
else:
39-
# No warning should be raised
36+
with warnings.catch_warnings(record=True) as w:
37+
warnings.simplefilter("always") # capture all warnings
4038
plugin = FSDPPrecision(precision=precision)
4139

40+
# Check if the warning was (or wasn’t) logged
41+
has_warn = any("FSDPPrecision" in str(warning.message) for warning in w)
42+
assert has_warn == expect_warn, f"Unexpected warning state for {precision}"
43+
4244
config = plugin.mixed_precision_config
4345

4446
assert config.param_dtype == expected[0]

tests/tests_pytorch/plugins/precision/test_fsdp.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
1415
from unittest.mock import ANY, MagicMock, Mock
1516

1617
import pytest
@@ -32,13 +33,14 @@
3233
],
3334
)
3435
def test_fsdp_precision_config(precision, expected, expect_warn):
35-
if expect_warn:
36-
with pytest.warns(UserWarning, match="FSDPPrecision.*runs computations in reduced precision"):
37-
plugin = FSDPPrecision(precision=precision)
38-
else:
39-
# No warning should be raised
36+
with warnings.catch_warnings(record=True) as w:
37+
warnings.simplefilter("always") # capture all warnings
4038
plugin = FSDPPrecision(precision=precision)
4139

40+
# Check if the warning was (or wasn’t) logged
41+
has_warn = any("FSDPPrecision" in str(warning.message) for warning in w)
42+
assert has_warn == expect_warn, f"Unexpected warning state for {precision}"
43+
4244
config = plugin.mixed_precision_config
4345

4446
assert config.param_dtype == expected[0]

0 commit comments

Comments
 (0)