File tree Expand file tree Collapse file tree 3 files changed +25
-1
lines changed Expand file tree Collapse file tree 3 files changed +25
-1
lines changed Original file line number Diff line number Diff line change @@ -492,7 +492,7 @@ def _check_and_init_precision(self) -> Precision:
492
492
if self ._precision_input == "16-mixed"
493
493
else "Using bfloat16 Automatic Mixed Precision (AMP)"
494
494
)
495
- device = "cpu" if self ._accelerator_flag == "cpu" else "cuda"
495
+ device = self . _accelerator_flag if self ._accelerator_flag in ( "cpu" , "mps" ) else "cuda"
496
496
return MixedPrecision (precision = self ._precision_input , device = device ) # type: ignore[arg-type]
497
497
498
498
raise RuntimeError ("No precision set" )
Original file line number Diff line number Diff line change @@ -405,6 +405,13 @@ def test_unsupported_strategy_types_on_cpu_and_fallback():
405
405
assert isinstance (connector .strategy , DDPStrategy )
406
406
407
407
408
+ @RunIf (mps = True )
409
+ @pytest .mark .parametrize ("precision" , ["16-mixed" , "bf16-mixed" ])
410
+ def test_mps_enabled_with_float16_or_bfloat16_precision (precision ):
411
+ connector = _Connector (accelerator = "mps" , precision = precision )
412
+ assert connector .precision .device == "mps"
413
+
414
+
408
415
def test_invalid_accelerator_choice ():
409
416
with pytest .raises (ValueError , match = "You selected an invalid accelerator name: `accelerator='cocofruit'`" ):
410
417
_Connector (accelerator = "cocofruit" )
Original file line number Diff line number Diff line change 12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
import os
15
+ import warnings
15
16
from contextlib import nullcontext
16
17
from re import escape
17
18
from unittest import mock
@@ -735,6 +736,22 @@ def test_autocast():
735
736
fabric ._precision .forward_context ().__exit__ .assert_called ()
736
737
737
738
739
+ @RunIf (mps = True )
740
+ @pytest .mark .parametrize ("precision" , ["16-mixed" , "bf16-mixed" ])
741
+ def test_autocast_does_not_use_cuda_on_mps (precision ):
742
+ """Ensure Fabric.autocast on MPS does not fall back to CUDA when using (bf)16-mixed precision."""
743
+ fabric = Fabric (accelerator = "mps" , precision = precision )
744
+ fabric .launch ()
745
+
746
+ with warnings .catch_warnings (record = True ) as w :
747
+ warnings .simplefilter ("always" )
748
+ with fabric .autocast ():
749
+ pass
750
+
751
+ for warning in w :
752
+ assert "device_type of 'cuda'" not in str (warning .message )
753
+
754
+
738
755
def test_no_backward_sync ():
739
756
"""Test that `Fabric.no_backward_sync()` validates the strategy and model is compatible."""
740
757
fabric = Fabric (devices = 1 )
You can’t perform that action at this time.
0 commit comments