Skip to content

Commit a9478dc

Browse files
Merge branch 'Lightning-AI:master' into master
2 parents 0773eb4 + ab7b118 commit a9478dc

File tree

7 files changed

+52
-4
lines changed

7 files changed

+52
-4
lines changed

.github/workflows/docker-build.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ jobs:
9797
# adding some more images as Thunder is mainly using python 3.10,
9898
# and we need to support integrations as for example LitGPT
9999
python_version: ["3.10"]
100-
pytorch_version: ["2.6.0", "2.7.0"]
100+
pytorch_version: ["2.6.0", "2.7.1"]
101101
cuda_version: ["12.6.3"]
102102
include:
103103
# These are the base images for PL release docker images.
@@ -108,7 +108,7 @@ jobs:
108108
- { python_version: "3.11", pytorch_version: "2.4.1", cuda_version: "12.1.1" }
109109
- { python_version: "3.12", pytorch_version: "2.5.1", cuda_version: "12.1.1" }
110110
- { python_version: "3.12", pytorch_version: "2.6.0", cuda_version: "12.4.1" }
111-
- { python_version: "3.12", pytorch_version: "2.7.0", cuda_version: "12.6.3" }
111+
- { python_version: "3.12", pytorch_version: "2.7.1", cuda_version: "12.6.3" }
112112
steps:
113113
- uses: actions/checkout@v4
114114
- uses: docker/setup-buildx-action@v3

requirements/typing.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
mypy==1.15.0
2-
torch==2.7.0
2+
torch==2.7.1
33

44
types-Markdown
55
types-PyYAML

src/lightning/fabric/connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def _check_and_init_precision(self) -> Precision:
492492
if self._precision_input == "16-mixed"
493493
else "Using bfloat16 Automatic Mixed Precision (AMP)"
494494
)
495-
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
495+
device = self._accelerator_flag if self._accelerator_flag in ("cpu", "mps") else "cuda"
496496
return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type]
497497

498498
raise RuntimeError("No precision set")

src/lightning/pytorch/core/module.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1472,6 +1472,10 @@ def forward(self, x):
14721472
)
14731473
example_inputs = self.example_input_array
14741474

1475+
if kwargs.get("check_inputs") is not None:
1476+
kwargs["check_inputs"] = self._on_before_batch_transfer(kwargs["check_inputs"])
1477+
kwargs["check_inputs"] = self._apply_batch_transfer_handler(kwargs["check_inputs"])
1478+
14751479
# automatically send example inputs to the right device and use trace
14761480
example_inputs = self._on_before_batch_transfer(example_inputs)
14771481
example_inputs = self._apply_batch_transfer_handler(example_inputs)

tests/tests_fabric/test_connector.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,13 @@ def test_unsupported_strategy_types_on_cpu_and_fallback():
405405
assert isinstance(connector.strategy, DDPStrategy)
406406

407407

408+
@RunIf(mps=True)
409+
@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
410+
def test_mps_enabled_with_float16_or_bfloat16_precision(precision):
411+
connector = _Connector(accelerator="mps", precision=precision)
412+
assert connector.precision.device == "mps"
413+
414+
408415
def test_invalid_accelerator_choice():
409416
with pytest.raises(ValueError, match="You selected an invalid accelerator name: `accelerator='cocofruit'`"):
410417
_Connector(accelerator="cocofruit")

tests/tests_fabric/test_fabric.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
import warnings
1516
from contextlib import nullcontext
1617
from re import escape
1718
from unittest import mock
@@ -735,6 +736,22 @@ def test_autocast():
735736
fabric._precision.forward_context().__exit__.assert_called()
736737

737738

739+
@RunIf(mps=True)
740+
@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
741+
def test_autocast_does_not_use_cuda_on_mps(precision):
742+
"""Ensure Fabric.autocast on MPS does not fall back to CUDA when using (bf)16-mixed precision."""
743+
fabric = Fabric(accelerator="mps", precision=precision)
744+
fabric.launch()
745+
746+
with warnings.catch_warnings(record=True) as w:
747+
warnings.simplefilter("always")
748+
with fabric.autocast():
749+
pass
750+
751+
for warning in w:
752+
assert "device_type of 'cuda'" not in str(warning.message)
753+
754+
738755
def test_no_backward_sync():
739756
"""Test that `Fabric.no_backward_sync()` validates the strategy and model is compatible."""
740757
fabric = Fabric(devices=1)

tests/tests_pytorch/models/test_torchscript.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,26 @@ def test_torchscript_device(device_str):
105105
assert script_output.device == device
106106

107107

108+
@pytest.mark.parametrize(
109+
"device_str",
110+
[
111+
"cpu",
112+
pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
113+
pytest.param("mps:0", marks=RunIf(mps=True)),
114+
],
115+
)
116+
def test_torchscript_device_with_check_inputs(device_str):
117+
"""Test that scripted module is on the correct device."""
118+
device = torch.device(device_str)
119+
model = BoringModel().to(device)
120+
model.example_input_array = torch.randn(5, 32)
121+
122+
check_inputs = torch.rand(5, 32)
123+
124+
script = model.to_torchscript(method="trace", check_inputs=check_inputs)
125+
assert isinstance(script, torch.jit.ScriptModule)
126+
127+
108128
def test_torchscript_retain_training_state():
109129
"""Test that torchscript export does not alter the training mode of original model."""
110130
model = BoringModel()

0 commit comments

Comments
 (0)