Skip to content

Commit 9fc30da

Browse files
committed
implement MPSAccelerator.device_name
1 parent 10a95b4 commit 9fc30da

File tree

1 file changed

+12
-2
lines changed
  • src/lightning/pytorch/accelerators

1 file changed

+12
-2
lines changed

src/lightning/pytorch/accelerators/mps.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import subprocess
1415
from typing import Any, Optional, Union
1516

1617
import torch
@@ -90,10 +91,19 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
9091
@classmethod
9192
@override
9293
def device_name(cls, device: Optional[_DEVICE] = None) -> str:
93-
# todo: implement a better way to get the device name
9494
if not cls.is_available():
9595
return ""
96-
return "True (mps)"
96+
try:
97+
result = subprocess.run(
98+
["sysctl", "-n", "machdep.cpu.brand_string"],
99+
capture_output=True,
100+
text=True,
101+
check=True,
102+
)
103+
result_str = result.stdout.strip()
104+
except subprocess.SubprocessError:
105+
result_str = "True (mps)"
106+
return result_str
97107

98108

99109
# device metrics

0 commit comments

Comments
 (0)