From 45f50b47693ae21f35af1dd7c9f93ffbc3bafcce Mon Sep 17 00:00:00 2001 From: Haga Device Date: Wed, 4 Jun 2025 15:34:12 +0200 Subject: [PATCH 1/3] Make sure MPS is used when chosen as accelerator in Fabric --- src/lightning/fabric/connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 0e0e86ee7c63e..55b4af2728e6f 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -492,7 +492,7 @@ def _check_and_init_precision(self) -> Precision: if self._precision_input == "16-mixed" else "Using bfloat16 Automatic Mixed Precision (AMP)" ) - device = "cpu" if self._accelerator_flag == "cpu" else "cuda" + device = self._accelerator_flag if self._accelerator_flag in ("cpu", "mps") else "cuda" return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type] raise RuntimeError("No precision set") From 66136593d1d724e682c3c935aaa2e6c81d203eb0 Mon Sep 17 00:00:00 2001 From: Haga Device Date: Thu, 5 Jun 2025 17:08:05 +0200 Subject: [PATCH 2/3] Added mps tests to connector and Fabric --- tests/tests_fabric/test_connector.py | 6 ++++++ tests/tests_fabric/test_fabric.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index 9bb9fa1d7d145..3d4113fcc356c 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -404,6 +404,12 @@ def test_unsupported_strategy_types_on_cpu_and_fallback(): connector = _Connector(accelerator="cpu", strategy="dp", devices=2) assert isinstance(connector.strategy, DDPStrategy) +@RunIf(mps=True) +@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"]) +def test_mps_enabled_with_float16_or_bfloat16_precision(precision): + connector = _Connector(accelerator="mps", precision=precision) + assert connector.precision.device == "mps" + def test_invalid_accelerator_choice(): with pytest.raises(ValueError, match="You selected an invalid accelerator name: `accelerator='cocofruit'`"): diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index ee002b5d8061c..49b9a68a8b18a 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import warnings from contextlib import nullcontext from re import escape from unittest import mock @@ -734,6 +735,20 @@ def test_autocast(): fabric._precision.forward_context().__enter__.assert_called() fabric._precision.forward_context().__exit__.assert_called() +@RunIf(mps=True) +@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"]) +def test_autocast_does_not_use_cuda_on_mps(precision): + """Ensure Fabric.autocast on MPS does not fall back to CUDA when using (bf)16-mixed precision.""" + fabric = Fabric(accelerator="mps", precision=precision) + fabric.launch() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + with fabric.autocast(): + pass + + for warning in w: + assert "device_type of 'cuda'" not in str(warning.message) def test_no_backward_sync(): """Test that `Fabric.no_backward_sync()` validates the strategy and model is compatible.""" From 26d15adbf4a58d6f724f92ad192c9f08cd3b91ac Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Jun 2025 15:14:08 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_fabric/test_connector.py | 1 + tests/tests_fabric/test_fabric.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index 3d4113fcc356c..c6bef5943a30f 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -404,6 +404,7 @@ def test_unsupported_strategy_types_on_cpu_and_fallback(): connector = _Connector(accelerator="cpu", strategy="dp", devices=2) assert isinstance(connector.strategy, DDPStrategy) + @RunIf(mps=True) @pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"]) def test_mps_enabled_with_float16_or_bfloat16_precision(precision): diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index 49b9a68a8b18a..dc0203dc067e3 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -735,6 +735,7 @@ def test_autocast(): fabric._precision.forward_context().__enter__.assert_called() fabric._precision.forward_context().__exit__.assert_called() + @RunIf(mps=True) @pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"]) def test_autocast_does_not_use_cuda_on_mps(precision): @@ -750,6 +751,7 @@ def test_autocast_does_not_use_cuda_on_mps(precision): for warning in w: assert "device_type of 'cuda'" not in str(warning.message) + def test_no_backward_sync(): """Test that `Fabric.no_backward_sync()` validates the strategy and model is compatible.""" fabric = Fabric(devices=1)