File tree Expand file tree Collapse file tree 2 files changed +14
-10
lines changed
tests_fabric/plugins/precision
tests_pytorch/plugins/precision Expand file tree Collapse file tree 2 files changed +14
-10
lines changed Original file line number Diff line number Diff line change 1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import warnings
1415from unittest .mock import Mock
1516
1617import pytest
3233 ],
3334)
3435def test_fsdp_precision_config (precision , expected , expect_warn ):
35- if expect_warn :
36- with pytest .warns (UserWarning , match = "FSDPPrecision.*runs computations in reduced precision" ):
37- plugin = FSDPPrecision (precision = precision )
38- else :
39- # No warning should be raised
36+ with warnings .catch_warnings (record = True ) as w :
37+ warnings .simplefilter ("always" ) # capture all warnings
4038 plugin = FSDPPrecision (precision = precision )
4139
40+ # Check if the warning was (or wasn’t) logged
41+ has_warn = any ("FSDPPrecision" in str (warning .message ) for warning in w )
42+ assert has_warn == expect_warn , f"Unexpected warning state for { precision } "
43+
4244 config = plugin .mixed_precision_config
4345
4446 assert config .param_dtype == expected [0 ]
Original file line number Diff line number Diff line change 1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import warnings
1415from unittest .mock import ANY , MagicMock , Mock
1516
1617import pytest
3233 ],
3334)
3435def test_fsdp_precision_config (precision , expected , expect_warn ):
35- if expect_warn :
36- with pytest .warns (UserWarning , match = "FSDPPrecision.*runs computations in reduced precision" ):
37- plugin = FSDPPrecision (precision = precision )
38- else :
39- # No warning should be raised
36+ with warnings .catch_warnings (record = True ) as w :
37+ warnings .simplefilter ("always" ) # capture all warnings
4038 plugin = FSDPPrecision (precision = precision )
4139
40+ # Check if the warning was (or wasn’t) logged
41+ has_warn = any ("FSDPPrecision" in str (warning .message ) for warning in w )
42+ assert has_warn == expect_warn , f"Unexpected warning state for { precision } "
43+
4244 config = plugin .mixed_precision_config
4345
4446 assert config .param_dtype == expected [0 ]
You can’t perform that action at this time.
0 commit comments