Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci-tests-fabric.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ jobs:
matrix:
include:
# only run PyTorch latest
- { os: "macOS-13", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ci-tests-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ jobs:
matrix:
include:
# only run PyTorch latest
- { os: "macOS-13", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
Expand Down
3 changes: 1 addition & 2 deletions src/lightning/fabric/accelerators/mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import platform
from functools import lru_cache
from typing import Optional, Union

Expand Down Expand Up @@ -72,7 +71,7 @@ def auto_device_count() -> int:
def is_available() -> bool:
"""MPS is only available on a machine with the ARM-based Apple Silicon processors."""
mps_disabled = os.getenv("DISABLE_MPS", "0") == "1"
return not mps_disabled and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64")
return not mps_disabled and torch.backends.mps.is_available()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the issue with the current checks?

We cannot easily remove them. MPS marks as available on older intel-based macs as well while not actually accelerating anything.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that is why we add testing back with Intel to see what is happening and if we even have related tests to cover or refuse this change

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if it passes, it is slower than CPU on intel-macs and due to auto-selection users would end up with a non-optimal backend.


@classmethod
@override
Expand Down
2 changes: 2 additions & 0 deletions tests/tests_pytorch/accelerators/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from lightning.pytorch.demos.boring_classes import BoringModel
from tests_pytorch.helpers.runif import RunIf

# trigger pytorch test


@RunIf(mps=True)
def test_get_mps_stats():
Expand Down
Loading