diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py
index 0e0e86ee7c63e..55b4af2728e6f 100644
--- a/src/lightning/fabric/connector.py
+++ b/src/lightning/fabric/connector.py
@@ -492,7 +492,7 @@ def _check_and_init_precision(self) -> Precision:
                 if self._precision_input == "16-mixed"
                 else "Using bfloat16 Automatic Mixed Precision (AMP)"
             )
-            device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
+            device = self._accelerator_flag if self._accelerator_flag in ("cpu", "mps") else "cuda"
             return MixedPrecision(precision=self._precision_input, device=device)  # type: ignore[arg-type]
 
         raise RuntimeError("No precision set")
diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py
index 9bb9fa1d7d145..c6bef5943a30f 100644
--- a/tests/tests_fabric/test_connector.py
+++ b/tests/tests_fabric/test_connector.py
@@ -405,6 +405,13 @@ def test_unsupported_strategy_types_on_cpu_and_fallback():
     assert isinstance(connector.strategy, DDPStrategy)
 
 
+@RunIf(mps=True)
+@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
+def test_mps_enabled_with_float16_or_bfloat16_precision(precision):
+    connector = _Connector(accelerator="mps", precision=precision)
+    assert connector.precision.device == "mps"
+
+
 def test_invalid_accelerator_choice():
     with pytest.raises(ValueError, match="You selected an invalid accelerator name: `accelerator='cocofruit'`"):
         _Connector(accelerator="cocofruit")
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index ee002b5d8061c..dc0203dc067e3 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import warnings
 from contextlib import nullcontext
 from re import escape
 from unittest import mock
@@ -735,6 +736,22 @@ def test_autocast():
     fabric._precision.forward_context().__exit__.assert_called()
 
 
+@RunIf(mps=True)
+@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
+def test_autocast_does_not_use_cuda_on_mps(precision):
+    """Ensure Fabric.autocast on MPS does not fall back to CUDA when using (bf)16-mixed precision."""
+    fabric = Fabric(accelerator="mps", precision=precision)
+    fabric.launch()
+
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        with fabric.autocast():
+            pass
+
+    for warning in w:
+        assert "device_type of 'cuda'" not in str(warning.message)
+
+
 def test_no_backward_sync():
     """Test that `Fabric.no_backward_sync()` validates the strategy and model is compatible."""
     fabric = Fabric(devices=1)
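
The connector change means that, on Apple Silicon, `MixedPrecision` is constructed with `device="mps"` instead of `device="cuda"`, so the autocast context it opens targets the MPS backend. The sketch below (not part of the patch) illustrates the behavior the new tests assert: no "CUDA is not available" warning is emitted because the correct `device_type` is passed to `torch.autocast`. It assumes a PyTorch build where `torch.autocast` supports the `"mps"` device type.

```python
import torch

# Minimal illustration of the fixed behavior: with the connector change,
# "(bf)16-mixed" precision on an MPS accelerator opens an autocast context
# for the "mps" backend rather than "cuda", so PyTorch no longer warns
# "User provided device_type of 'cuda', but CUDA is not available".
if torch.backends.mps.is_available():
    with torch.autocast(device_type="mps", dtype=torch.float16):
        x = torch.randn(4, 4, device="mps")
        y = x @ x  # matmul runs under float16 autocast on the MPS backend
```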