Skip to content

Commit eda4aa7

Browse files
Fix type bug in to() method (#55)
* fix type bug in to() method * added test for to() method * update tests --------- Co-authored-by: Rodrigo <[email protected]>
1 parent 808099b commit eda4aa7

File tree

2 files changed

+63
-27
lines changed

2 files changed

+63
-27
lines changed

cebra/integrations/sklearn/cebra.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,21 +1282,21 @@ def to(self, device: Union[str, torch.device]):
12821282
raise TypeError(
12831283
"The 'device' parameter must be a string or torch.device object."
12841284
)
1285-
1286-
if (not device == 'cpu') and (not device.startswith('cuda')) and (
1287-
not device == 'mps'):
1288-
raise ValueError(
1289-
"The 'device' parameter must be a valid device string or device object."
1290-
)
1291-
1285+
12921286
if isinstance(device, str):
1293-
device = torch.device(device)
1287+
if (not device == 'cpu') and (not device.startswith('cuda')) and (
1288+
not device == 'mps'):
1289+
raise ValueError(
1290+
"The 'device' parameter must be a valid device string or device object."
1291+
)
12941292

1295-
if (not device.type == 'cpu') and (
1296-
not device.type.startswith('cuda')) and (not device == 'mps'):
1297-
raise ValueError(
1298-
"The 'device' parameter must be a valid device string or device object."
1299-
)
1293+
elif isinstance(device, torch.device):
1294+
if (not device.type == 'cpu') and (
1295+
not device.type.startswith('cuda')) and (not device == 'mps'):
1296+
raise ValueError(
1297+
"The 'device' parameter must be a valid device string or device object."
1298+
)
1299+
device = device.type
13001300

13011301
if hasattr(self, "device_"):
13021302
self.device_ = device

tests/test_sklearn.py

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,39 @@ def get_ordered_cuda_devices():
856856
ordered_cuda_devices = get_ordered_cuda_devices() if torch.cuda.is_available(
857857
) else []
858858

859+
def test_fit_after_moving_to_device():
860+
expected_device = 'cpu'
861+
expected_type = type(expected_device)
862+
863+
X = np.random.uniform(0, 1, (10, 5))
864+
cebra_model = cebra_sklearn_cebra.CEBRA(model_architecture="offset1-model",
865+
max_iterations=5,
866+
device=expected_device)
867+
868+
assert type(cebra_model.device) == expected_type
869+
assert cebra_model.device == expected_device
870+
871+
cebra_model.partial_fit(X)
872+
assert type(cebra_model.device) == expected_type
873+
assert cebra_model.device == expected_device
874+
if hasattr(cebra_model, 'device_'):
875+
assert type(cebra_model.device_) == expected_type
876+
assert cebra_model.device_ == expected_device
877+
878+
# Move the model to device using the to() method
879+
cebra_model.to('cpu')
880+
assert type(cebra_model.device) == expected_type
881+
assert cebra_model.device == expected_device
882+
if hasattr(cebra_model, 'device_'):
883+
assert type(cebra_model.device_) == expected_type
884+
assert cebra_model.device_ == expected_device
885+
886+
cebra_model.partial_fit(X)
887+
assert type(cebra_model.device) == expected_type
888+
assert cebra_model.device == expected_device
889+
if hasattr(cebra_model, 'device_'):
890+
assert type(cebra_model.device_) == expected_type
891+
assert cebra_model.device_ == expected_device
859892

860893
@pytest.mark.parametrize("device", ['cpu'] + ordered_cuda_devices)
861894
def test_move_cpu_to_cuda_device(device):
@@ -875,9 +908,12 @@ def test_move_cpu_to_cuda_device(device):
875908
new_device = 'cpu' if device.startswith('cuda') else 'cuda:0'
876909
cebra_model.to(new_device)
877910

878-
assert cebra_model.device == torch.device(new_device)
879-
assert next(cebra_model.solver_.model.parameters()).device == torch.device(
880-
new_device)
911+
assert cebra_model.device == new_device
912+
device_model = next(cebra_model.solver_.model.parameters()).device
913+
device_str = str(device_model)
914+
if device_model.type == 'cuda':
915+
device_str = f'cuda:{device_model.index}'
916+
assert device_str == new_device
881917

882918
with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
883919
cebra_model.save(savefile.name)
@@ -903,9 +939,10 @@ def test_move_cpu_to_mps_device(device):
903939
new_device = 'cpu' if device == 'mps' else 'mps'
904940
cebra_model.to(new_device)
905941

906-
assert cebra_model.device == torch.device(new_device)
907-
assert next(cebra_model.solver_.model.parameters()).device == torch.device(
908-
new_device)
942+
assert cebra_model.device == new_device
943+
944+
device_model = next(cebra_model.solver_.model.parameters()).device
945+
assert device_model.type == new_device
909946

910947
with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
911948
cebra_model.save(savefile.name)
@@ -939,9 +976,12 @@ def test_move_mps_to_cuda_device(device):
939976
new_device = 'mps' if device.startswith('cuda') else 'cuda:0'
940977
cebra_model.to(new_device)
941978

942-
assert cebra_model.device == torch.device(new_device)
943-
assert next(cebra_model.solver_.model.parameters()).device == torch.device(
944-
new_device)
979+
assert cebra_model.device == new_device
980+
device_model = next(cebra_model.solver_.model.parameters()).device
981+
device_str = str(device_model)
982+
if device_model.type == 'cuda':
983+
device_str = f'cuda:{device_model.index}'
984+
assert device_str == new_device
945985

946986
with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
947987
cebra_model.save(savefile.name)
@@ -963,11 +1003,7 @@ def test_mps():
9631003

9641004
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
9651005
torch.backends.mps.is_available = lambda: False
966-
with pytest.raises(ValueError):
967-
cebra_model.fit(X)
968-
969-
torch.backends.mps.is_available = lambda: True
970-
torch.backends.mps.is_built = lambda: False
1006+
9711007
with pytest.raises(ValueError):
9721008
cebra_model.fit(X)
9731009

0 commit comments

Comments
 (0)