File tree Expand file tree Collapse file tree 7 files changed +52
-4
lines changed Expand file tree Collapse file tree 7 files changed +52
-4
lines changed Original file line number Diff line number Diff line change 9797 # adding dome more images as Thunder mainly using python 3.10,
9898 # and we need to support integrations as for example LitGPT
9999 python_version : ["3.10"]
100- pytorch_version : ["2.6.0", "2.7.0 "]
100+ pytorch_version : ["2.6.0", "2.7.1 "]
101101 cuda_version : ["12.6.3"]
102102 include :
103103 # These are the base images for PL release docker images.
@@ -108,7 +108,7 @@ jobs:
108108 - { python_version: "3.11", pytorch_version: "2.4.1", cuda_version: "12.1.1" }
109109 - { python_version: "3.12", pytorch_version: "2.5.1", cuda_version: "12.1.1" }
110110 - { python_version: "3.12", pytorch_version: "2.6.0", cuda_version: "12.4.1" }
111- - { python_version: "3.12", pytorch_version: "2.7.0 ", cuda_version: "12.6.3" }
111+ - { python_version: "3.12", pytorch_version: "2.7.1 ", cuda_version: "12.6.3" }
112112 steps :
113113 - uses : actions/checkout@v4
114114 - uses : docker/setup-buildx-action@v3
Original file line number Diff line number Diff line change 11mypy==1.15.0
2- torch==2.7.0
2+ torch==2.7.1
33
44types-Markdown
55types-PyYAML
Original file line number Diff line number Diff line change @@ -492,7 +492,7 @@ def _check_and_init_precision(self) -> Precision:
492492 if self ._precision_input == "16-mixed"
493493 else "Using bfloat16 Automatic Mixed Precision (AMP)"
494494 )
495- device = "cpu" if self ._accelerator_flag == "cpu" else "cuda"
495+ device = self . _accelerator_flag if self ._accelerator_flag in ( "cpu" , "mps" ) else "cuda"
496496 return MixedPrecision (precision = self ._precision_input , device = device ) # type: ignore[arg-type]
497497
498498 raise RuntimeError ("No precision set" )
Original file line number Diff line number Diff line change @@ -1486,6 +1486,10 @@ def forward(self, x):
14861486 )
14871487 example_inputs = self .example_input_array
14881488
1489+ if kwargs .get ("check_inputs" ) is not None :
1490+ kwargs ["check_inputs" ] = self ._on_before_batch_transfer (kwargs ["check_inputs" ])
1491+ kwargs ["check_inputs" ] = self ._apply_batch_transfer_handler (kwargs ["check_inputs" ])
1492+
14891493 # automatically send example inputs to the right device and use trace
14901494 example_inputs = self ._on_before_batch_transfer (example_inputs )
14911495 example_inputs = self ._apply_batch_transfer_handler (example_inputs )
Original file line number Diff line number Diff line change @@ -405,6 +405,13 @@ def test_unsupported_strategy_types_on_cpu_and_fallback():
405405 assert isinstance (connector .strategy , DDPStrategy )
406406
407407
408+ @RunIf (mps = True )
409+ @pytest .mark .parametrize ("precision" , ["16-mixed" , "bf16-mixed" ])
410+ def test_mps_enabled_with_float16_or_bfloat16_precision (precision ):
411+ connector = _Connector (accelerator = "mps" , precision = precision )
412+ assert connector .precision .device == "mps"
413+
414+
408415def test_invalid_accelerator_choice ():
409416 with pytest .raises (ValueError , match = "You selected an invalid accelerator name: `accelerator='cocofruit'`" ):
410417 _Connector (accelerator = "cocofruit" )
Original file line number Diff line number Diff line change 1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import os
15+ import warnings
1516from contextlib import nullcontext
1617from re import escape
1718from unittest import mock
@@ -735,6 +736,22 @@ def test_autocast():
735736 fabric ._precision .forward_context ().__exit__ .assert_called ()
736737
737738
739+ @RunIf (mps = True )
740+ @pytest .mark .parametrize ("precision" , ["16-mixed" , "bf16-mixed" ])
741+ def test_autocast_does_not_use_cuda_on_mps (precision ):
742+ """Ensure Fabric.autocast on MPS does not fall back to CUDA when using (bf)16-mixed precision."""
743+ fabric = Fabric (accelerator = "mps" , precision = precision )
744+ fabric .launch ()
745+
746+ with warnings .catch_warnings (record = True ) as w :
747+ warnings .simplefilter ("always" )
748+ with fabric .autocast ():
749+ pass
750+
751+ for warning in w :
752+ assert "device_type of 'cuda'" not in str (warning .message )
753+
754+
738755def test_no_backward_sync ():
739756 """Test that `Fabric.no_backward_sync()` validates the strategy and model is compatible."""
740757 fabric = Fabric (devices = 1 )
Original file line number Diff line number Diff line change @@ -105,6 +105,26 @@ def test_torchscript_device(device_str):
105105 assert script_output .device == device
106106
107107
108+ @pytest .mark .parametrize (
109+ "device_str" ,
110+ [
111+ "cpu" ,
112+ pytest .param ("cuda:0" , marks = RunIf (min_cuda_gpus = 1 )),
113+ pytest .param ("mps:0" , marks = RunIf (mps = True )),
114+ ],
115+ )
116+ def test_torchscript_device_with_check_inputs (device_str ):
117+ """Test that scripted module is on the correct device."""
118+ device = torch .device (device_str )
119+ model = BoringModel ().to (device )
120+ model .example_input_array = torch .randn (5 , 32 )
121+
122+ check_inputs = torch .rand (5 , 32 )
123+
124+ script = model .to_torchscript (method = "trace" , check_inputs = check_inputs )
125+ assert isinstance (script , torch .jit .ScriptModule )
126+
127+
108128def test_torchscript_retain_training_state ():
109129 """Test that torchscript export does not alter the training mode of original model."""
110130 model = BoringModel ()
You can’t perform that action at this time.
0 commit comments