
Commit 14b6c3e

Armannas, Haga Device, and pre-commit-ci[bot]
authored
Ensure correct device is used for autocast when mps is selected as Fabric accelerator (#20876)
* Make sure MPS is used when chosen as accelerator in Fabric
* Added mps tests to connector and Fabric

Co-authored-by: Haga Device <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 821611b commit 14b6c3e

File tree

src/lightning/fabric/connector.py
tests/tests_fabric/test_connector.py
tests/tests_fabric/test_fabric.py

3 files changed: +25 −1 lines changed


src/lightning/fabric/connector.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -492,7 +492,7 @@ def _check_and_init_precision(self) -> Precision:
                 if self._precision_input == "16-mixed"
                 else "Using bfloat16 Automatic Mixed Precision (AMP)"
             )
-            device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
+            device = self._accelerator_flag if self._accelerator_flag in ("cpu", "mps") else "cuda"
             return MixedPrecision(precision=self._precision_input, device=device)  # type: ignore[arg-type]
 
         raise RuntimeError("No precision set")
```
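This one-line change matters because the `device` string chosen here becomes the `device_type` that the `MixedPrecision` plugin hands to `torch.autocast`, so selecting the MPS accelerator previously produced a CUDA autocast context. A minimal sketch of the fixed selection logic, assuming only that PyTorch is installed (the helper name `_autocast_device` is hypothetical, used here purely for illustration):

```python
import torch


def _autocast_device(accelerator_flag: str) -> str:
    # Mirrors the fixed connector line: keep "cpu" and "mps" as-is,
    # fall back to "cuda" for every other accelerator flag.
    return accelerator_flag if accelerator_flag in ("cpu", "mps") else "cuda"


# Pick a flag that works on the current machine.
flag = "mps" if torch.backends.mps.is_available() else "cpu"
device = _autocast_device(flag)

# Autocast now targets the selected backend instead of unconditionally using "cuda".
dtype = torch.float16 if device == "mps" else torch.bfloat16
with torch.autocast(device_type=device, dtype=dtype):
    x = torch.ones(4, 4, device=device)
    y = x @ x  # runs inside the mixed-precision context
```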

tests/tests_fabric/test_connector.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -405,6 +405,13 @@ def test_unsupported_strategy_types_on_cpu_and_fallback():
     assert isinstance(connector.strategy, DDPStrategy)
 
 
+@RunIf(mps=True)
+@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
+def test_mps_enabled_with_float16_or_bfloat16_precision(precision):
+    connector = _Connector(accelerator="mps", precision=precision)
+    assert connector.precision.device == "mps"
+
+
 def test_invalid_accelerator_choice():
     with pytest.raises(ValueError, match="You selected an invalid accelerator name: `accelerator='cocofruit'`"):
         _Connector(accelerator="cocofruit")
```
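For a quick manual check outside the test suite, the same behavior can be observed directly. This snippet is illustrative only and assumes an Apple Silicon machine with MPS available; either parametrized precision value works:

```python
from lightning.fabric.connector import _Connector

# "16-mixed" could equally be "bf16-mixed" here.
connector = _Connector(accelerator="mps", precision="16-mixed")
print(connector.precision.device)  # "mps" after this commit; "cuda" before it
```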

tests/tests_fabric/test_fabric.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import warnings
 from contextlib import nullcontext
 from re import escape
 from unittest import mock
@@ -735,6 +736,22 @@ def test_autocast():
     fabric._precision.forward_context().__exit__.assert_called()
 
 
+@RunIf(mps=True)
+@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
+def test_autocast_does_not_use_cuda_on_mps(precision):
+    """Ensure Fabric.autocast on MPS does not fall back to CUDA when using (bf)16-mixed precision."""
+    fabric = Fabric(accelerator="mps", precision=precision)
+    fabric.launch()
+
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        with fabric.autocast():
+            pass
+
+    for warning in w:
+        assert "device_type of 'cuda'" not in str(warning.message)
+
+
 def test_no_backward_sync():
     """Test that `Fabric.no_backward_sync()` validates the strategy and model is compatible."""
     fabric = Fabric(devices=1)
```
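The warnings-based assertion works because `torch.autocast` emits a `UserWarning` containing "device_type of 'cuda'" when a CUDA context is requested on a machine where CUDA is unavailable, which is exactly what happened on MPS machines before this fix. A standalone sketch of that failure mode, assuming a machine without CUDA:

```python
import warnings

import torch

# Requesting a CUDA autocast context without CUDA triggers the warning that
# test_autocast_does_not_use_cuda_on_mps asserts is absent after the fix.
with warnings.catch_warnings(record=True) as captured:
    warnings.simplefilter("always")
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        pass

for w in captured:
    print(w.category.__name__, w.message)
```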
