File tree Expand file tree Collapse file tree 7 files changed +52
-4
lines changed Expand file tree Collapse file tree 7 files changed +52
-4
lines changed Original file line number Diff line number Diff line change 97
97
# adding dome more images as Thunder mainly using python 3.10,
98
98
# and we need to support integrations as for example LitGPT
99
99
python_version : ["3.10"]
100
- pytorch_version : ["2.6.0", "2.7.0 "]
100
+ pytorch_version : ["2.6.0", "2.7.1 "]
101
101
cuda_version : ["12.6.3"]
102
102
include :
103
103
# These are the base images for PL release docker images.
@@ -108,7 +108,7 @@ jobs:
108
108
- { python_version: "3.11", pytorch_version: "2.4.1", cuda_version: "12.1.1" }
109
109
- { python_version: "3.12", pytorch_version: "2.5.1", cuda_version: "12.1.1" }
110
110
- { python_version: "3.12", pytorch_version: "2.6.0", cuda_version: "12.4.1" }
111
- - { python_version: "3.12", pytorch_version: "2.7.0 ", cuda_version: "12.6.3" }
111
+ - { python_version: "3.12", pytorch_version: "2.7.1 ", cuda_version: "12.6.3" }
112
112
steps :
113
113
- uses : actions/checkout@v4
114
114
- uses : docker/setup-buildx-action@v3
Original file line number Diff line number Diff line change 1
1
mypy==1.15.0
2
- torch==2.7.0
2
+ torch==2.7.1
3
3
4
4
types-Markdown
5
5
types-PyYAML
Original file line number Diff line number Diff line change @@ -492,7 +492,7 @@ def _check_and_init_precision(self) -> Precision:
492
492
if self ._precision_input == "16-mixed"
493
493
else "Using bfloat16 Automatic Mixed Precision (AMP)"
494
494
)
495
- device = "cpu" if self ._accelerator_flag == "cpu" else "cuda"
495
+ device = self . _accelerator_flag if self ._accelerator_flag in ( "cpu" , "mps" ) else "cuda"
496
496
return MixedPrecision (precision = self ._precision_input , device = device ) # type: ignore[arg-type]
497
497
498
498
raise RuntimeError ("No precision set" )
Original file line number Diff line number Diff line change @@ -1472,6 +1472,10 @@ def forward(self, x):
1472
1472
)
1473
1473
example_inputs = self .example_input_array
1474
1474
1475
+ if kwargs .get ("check_inputs" ) is not None :
1476
+ kwargs ["check_inputs" ] = self ._on_before_batch_transfer (kwargs ["check_inputs" ])
1477
+ kwargs ["check_inputs" ] = self ._apply_batch_transfer_handler (kwargs ["check_inputs" ])
1478
+
1475
1479
# automatically send example inputs to the right device and use trace
1476
1480
example_inputs = self ._on_before_batch_transfer (example_inputs )
1477
1481
example_inputs = self ._apply_batch_transfer_handler (example_inputs )
Original file line number Diff line number Diff line change @@ -405,6 +405,13 @@ def test_unsupported_strategy_types_on_cpu_and_fallback():
405
405
assert isinstance (connector .strategy , DDPStrategy )
406
406
407
407
408
+ @RunIf (mps = True )
409
+ @pytest .mark .parametrize ("precision" , ["16-mixed" , "bf16-mixed" ])
410
+ def test_mps_enabled_with_float16_or_bfloat16_precision (precision ):
411
+ connector = _Connector (accelerator = "mps" , precision = precision )
412
+ assert connector .precision .device == "mps"
413
+
414
+
408
415
def test_invalid_accelerator_choice ():
409
416
with pytest .raises (ValueError , match = "You selected an invalid accelerator name: `accelerator='cocofruit'`" ):
410
417
_Connector (accelerator = "cocofruit" )
Original file line number Diff line number Diff line change 12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
import os
15
+ import warnings
15
16
from contextlib import nullcontext
16
17
from re import escape
17
18
from unittest import mock
@@ -735,6 +736,22 @@ def test_autocast():
735
736
fabric ._precision .forward_context ().__exit__ .assert_called ()
736
737
737
738
739
+ @RunIf (mps = True )
740
+ @pytest .mark .parametrize ("precision" , ["16-mixed" , "bf16-mixed" ])
741
+ def test_autocast_does_not_use_cuda_on_mps (precision ):
742
+ """Ensure Fabric.autocast on MPS does not fall back to CUDA when using (bf)16-mixed precision."""
743
+ fabric = Fabric (accelerator = "mps" , precision = precision )
744
+ fabric .launch ()
745
+
746
+ with warnings .catch_warnings (record = True ) as w :
747
+ warnings .simplefilter ("always" )
748
+ with fabric .autocast ():
749
+ pass
750
+
751
+ for warning in w :
752
+ assert "device_type of 'cuda'" not in str (warning .message )
753
+
754
+
738
755
def test_no_backward_sync ():
739
756
"""Test that `Fabric.no_backward_sync()` validates the strategy and model is compatible."""
740
757
fabric = Fabric (devices = 1 )
Original file line number Diff line number Diff line change @@ -105,6 +105,26 @@ def test_torchscript_device(device_str):
105
105
assert script_output .device == device
106
106
107
107
108
+ @pytest .mark .parametrize (
109
+ "device_str" ,
110
+ [
111
+ "cpu" ,
112
+ pytest .param ("cuda:0" , marks = RunIf (min_cuda_gpus = 1 )),
113
+ pytest .param ("mps:0" , marks = RunIf (mps = True )),
114
+ ],
115
+ )
116
+ def test_torchscript_device_with_check_inputs (device_str ):
117
+ """Test that scripted module is on the correct device."""
118
+ device = torch .device (device_str )
119
+ model = BoringModel ().to (device )
120
+ model .example_input_array = torch .randn (5 , 32 )
121
+
122
+ check_inputs = torch .rand (5 , 32 )
123
+
124
+ script = model .to_torchscript (method = "trace" , check_inputs = check_inputs )
125
+ assert isinstance (script , torch .jit .ScriptModule )
126
+
127
+
108
128
def test_torchscript_retain_training_state ():
109
129
"""Test that torchscript export does not alter the training mode of original model."""
110
130
model = BoringModel ()
You can’t perform that action at this time.
0 commit comments