Commit 00601fb
Improve device support and add support for Apple Silicon chipset (mps) (#34)

* add to() method to move cebra models (sklearn API) between devices
* better name of test
* assign self.device_ only if it exists
* modify check_device() to allow GPU id specification
* adapt test given the possibility of specifying GPU ids
* add support for mps device
* add mps to _set_device() in io
* add mps logic when cuda_if_available + fix test for torch versions < 1.12
* fix test when cuda is not available
* fix test when pytorch < 1.12

1 parent c95bd5a commit 00601fb

File tree

5 files changed: +345 -21 lines changed

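In practice, the commit extends the device handling of the sklearn-style estimator: device may now be 'cpu', 'cuda', 'cuda:<id>', 'mps', or 'cuda_if_available' (which now falls back to 'mps' before 'cpu'), and a fitted model can be moved with the new to() method. Below is a minimal usage sketch of that surface, assuming cebra and a recent PyTorch are installed; whether a CUDA or MPS device is actually present depends on the machine.

import numpy as np
import cebra

X = np.random.uniform(0, 1, (1000, 30))

# "cuda_if_available" now resolves to "cuda", then "mps" (Apple Silicon), then "cpu".
model = cebra.CEBRA(max_iterations=10, device="cuda_if_available")
model.fit(X)

# New in this commit: move the fitted model (sklearn API) to another device.
model = model.to("cpu")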

cebra/helper.py

Lines changed: 12 additions & 0 deletions

@@ -21,6 +21,7 @@
 
 import numpy as np
 import numpy.typing as npt
+import pkg_resources
 import requests
 import torch
 
@@ -61,6 +62,17 @@ def download_file_from_zip_url(url, file="montblanc_tracks.h5"):
     return pathlib.Path(foldername) / "data" / file
 
 
+def _is_mps_availabe(torch):
+    available = False
+    if pkg_resources.parse_version(
+            torch.__version__) >= pkg_resources.parse_version("1.12"):
+        if torch.backends.mps.is_available():
+            if torch.backends.mps.is_built():
+                available = True
+
+    return available
+
+
 def _is_integer(y: Union[npt.NDArray, torch.Tensor]) -> bool:
     """Check if the values in ``y`` are :py:class:`int`.
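
The helper gates MPS on the installed PyTorch version (the mps backend first shipped in 1.12) before touching torch.backends.mps at all, and then requires both is_available() and is_built(). A standalone sketch of the same check follows; the name check_mps is illustrative and not part of the codebase.

import pkg_resources
import torch

def check_mps() -> bool:
    # The mps backend only exists from PyTorch 1.12 on; checking the version
    # first avoids an AttributeError on older installs.
    if pkg_resources.parse_version(
            torch.__version__) < pkg_resources.parse_version("1.12"):
        return False
    # The backend must be both compiled in and usable on this machine.
    return torch.backends.mps.is_available() and torch.backends.mps.is_built()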

cebra/integrations/sklearn/cebra.py

Lines changed: 49 additions & 0 deletions

@@ -1256,3 +1256,52 @@ def load(cls,
             raise RuntimeError("Model loaded from file is not compatible with "
                                "the current CEBRA version.")
         return model
+
+    def to(self, device: Union[str, torch.device]):
+        """Moves the cebra model to the specified device.
+
+        Args:
+            device: The device to move the cebra model to. This can be a string representing
+                the device ('cpu', 'cuda', 'cuda:device_id', or 'mps') or a torch.device object.
+
+        Returns:
+            The cebra model instance.
+
+        Example:
+
+            >>> import cebra
+            >>> import numpy as np
+            >>> dataset = np.random.uniform(0, 1, (1000, 30))
+            >>> cebra_model = cebra.CEBRA(max_iterations=10, device = "cuda_if_available")
+            >>> cebra_model.fit(dataset)
+            CEBRA(max_iterations=10)
+            >>> cebra_model = cebra_model.to("cpu")
+        """
+
+        if not isinstance(device, (str, torch.device)):
+            raise TypeError(
+                "The 'device' parameter must be a string or torch.device object."
+            )
+
+        if (not device == 'cpu') and (not device.startswith('cuda')) and (
+                not device == 'mps'):
+            raise ValueError(
+                "The 'device' parameter must be a valid device string or device object."
+            )
+
+        if isinstance(device, str):
+            device = torch.device(device)
+
+        if (not device.type == 'cpu') and (
+                not device.type.startswith('cuda')) and (not device == 'mps'):
+            raise ValueError(
+                "The 'device' parameter must be a valid device string or device object."
+            )
+
+        if hasattr(self, "device_"):
+            self.device_ = device
+
+        self.device = device
+        self.solver_.model.to(device)
+
+        return self
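
The new to() method keeps the estimator's bookkeeping and the underlying torch model in sync: it updates self.device, updates self.device_ when the model has been fitted, and moves self.solver_.model. A short sketch of moving a fitted model around, under the assumption that the target devices actually exist on the machine:

import numpy as np
import cebra

X = np.random.uniform(0, 1, (1000, 30))
model = cebra.CEBRA(max_iterations=10, device="cuda_if_available").fit(X)

model = model.to("cuda:0")   # a specific GPU id, if one is available
model = model.to("mps")      # Apple Silicon, if PyTorch >= 1.12 with MPS built
model = model.to("cpu")      # always valid
embedding = model.transform(X)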

cebra/integrations/sklearn/utils.py

Lines changed: 39 additions & 3 deletions

@@ -15,6 +15,8 @@
 import sklearn.utils.validation as sklearn_utils_validation
 import torch
 
+import cebra.helper
+
 
 def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple:
     """Handle deprecated arguments of a function until they are replaced.
@@ -114,16 +116,50 @@ def check_device(device: str) -> str:
         device: The device to return, if possible.
 
     Returns:
-        Either cuda or cpu depending on {device} and availability in the environment.
+        Either cuda, cuda:device_id, mps, or cpu depending on {device} and availability in the environment.
     """
+
     if device == "cuda_if_available":
         if torch.cuda.is_available():
             return "cuda"
+        elif cebra.helper._is_mps_availabe(torch):
+            return "mps"
         else:
             return "cpu"
-    elif device in ["cuda", "cpu"]:
+    elif device.startswith("cuda:") and len(device) > 5:
+        cuda_device_id = device[5:]
+        if cuda_device_id.isdigit():
+            device_count = torch.cuda.device_count()
+            device_id = int(cuda_device_id)
+            if device_id < device_count:
+                return f"cuda:{device_id}"
+            else:
+                raise ValueError(
+                    f"CUDA device {device_id} is not available. Available device IDs are 0 to {device_count - 1}."
+                )
+        else:
+            raise ValueError(
+                f"Invalid CUDA device ID format. Please use 'cuda:device_id' where '{cuda_device_id}' is an integer."
+            )
+    elif device == "cuda" and torch.cuda.is_available():
+        return "cuda:0"
+    elif device == "cpu":
         return device
-    raise ValueError(f"Device needs to be cuda or cpu, but got {device}.")
+    elif device == "mps":
+        if not torch.backends.mps.is_available():
+            if not torch.backends.mps.is_built():
+                raise ValueError(
+                    "MPS not available because the current PyTorch install was not "
+                    "built with MPS enabled.")
+            else:
+                raise ValueError(
+                    "MPS not available because the current MacOS version is not 12.3+ "
+                    "and/or you do not have an MPS-enabled device on this machine."
+                )
+
+        return device
+
+
+    raise ValueError(f"Device needs to be cuda, cpu or mps, but got {device}.")
 
 
 def check_fitted(model: "cebra.models.Model") -> bool:
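
check_device() now normalizes a wider set of inputs: 'cuda_if_available' resolves to 'cuda', then 'mps', then 'cpu'; a bare 'cuda' becomes 'cuda:0' when CUDA is usable; 'cuda:<id>' is validated against torch.cuda.device_count(); and 'mps' raises with a message that distinguishes a PyTorch build without MPS from unsupported hardware or macOS. A sketch of the expected mappings (the results noted in the comments depend on the hardware the code runs on):

from cebra.integrations.sklearn.utils import check_device

check_device("cpu")                # -> "cpu"
check_device("cuda")               # -> "cuda:0" when CUDA is available
check_device("cuda:1")             # -> "cuda:1", or ValueError if the id is out of range
check_device("mps")                # -> "mps", or ValueError explaining why MPS is unusable
check_device("cuda_if_available")  # -> "cuda", else "mps", else "cpu"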

cebra/io.py

Lines changed: 1 addition & 1 deletion

@@ -75,7 +75,7 @@ def _set_device(self, device):
             return
         if not isinstance(device, str):
            device = device.type
-        if device not in ("cpu", "cuda"):
+        if device not in ("cpu", "cuda", "mps"):
             if device.startswith("cuda"):
                 _, id_ = device.split(":")
                 if int(id_) >= torch.cuda.device_count():
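
On the io side the change is small: 'mps' joins 'cpu' and 'cuda' as an accepted device string, while torch.device objects are still normalized through their .type attribute and 'cuda:<id>' strings are still checked against the number of visible CUDA devices. A minimal standalone sketch of that normalization pattern; normalize_device is an illustrative name, not part of the library:

import torch

def normalize_device(device) -> str:
    # Illustrative only: mirrors the checks around cebra.io._set_device().
    if not isinstance(device, str):
        device = device.type  # torch.device("cuda:1").type == "cuda"; the index is dropped
    if device in ("cpu", "cuda", "mps"):
        return device
    if device.startswith("cuda:"):
        _, id_ = device.split(":")
        if int(id_) >= torch.cuda.device_count():
            raise ValueError(f"CUDA device id {id_} exceeds the available device count.")
        return device
    raise ValueError(f"Unsupported device: {device}")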
