Skip to content

Commit b707a5a

Browse files
authored
Fix Device in Multisession Training (#44)
* add device support to DatasetCollection * add tests * add device support to DatasetCollection * add tests * move helper function out of the class * fix typo in test * fix test * fix test_check_devices when cpu only available * Update docstring * Remove comment * Run pre-commit formatting
1 parent f3c4f0a commit b707a5a

File tree

2 files changed

+58
-6
lines changed

2 files changed

+58
-6
lines changed

cebra/data/datasets.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,25 @@ def __getitem__(self, index):
108108
return self.neural[index].transpose(2, 1)
109109

110110

111+
def _assert_datasets_same_device(
112+
datasets: List[cebra_data.SingleSessionDataset]) -> str:
113+
"""Checks if the list of datasets are all on the same device.
114+
115+
Args:
116+
datasets: List of datasets.
117+
118+
Returns:
119+
The device name if all datasets are on the same device.
120+
121+
Raises:
122+
ValueError: If datasets are not all on the same device.
123+
"""
124+
devices = set([dataset.device for dataset in datasets])
125+
if len(devices) != 1:
126+
raise ValueError("Datasets are not all on the same device")
127+
return devices.pop()
128+
129+
111130
class DatasetCollection(cebra_data.MultiSessionDataset):
112131
"""Multi session dataset made up of a list of datasets.
113132
@@ -165,11 +184,13 @@ def __init__(
165184
self,
166185
*datasets: cebra_data.SingleSessionDataset,
167186
):
168-
super().__init__()
169187
self._datasets: List[
170188
cebra_data.SingleSessionDataset] = self._unpack_dataset_arguments(
171189
datasets)
172190

191+
device = _assert_datasets_same_device(self._datasets)
192+
super().__init__(device=device)
193+
173194
continuous = all(
174195
self._has_not_none_attribute(session, "continuous_index")
175196
for session in self.iter_sessions())

tests/test_sklearn.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -758,12 +758,17 @@ def _iterate_actions():
758758
def do_nothing(model):
759759
return model
760760

761-
def fit_model(model):
761+
def fit_singlesession_model(model):
762762
X = np.linspace(-1, 1, 1000)[:, None]
763763
model.fit(X)
764764
return model
765765

766-
return [do_nothing, fit_model]
766+
def fit_multisession_model(model):
767+
X = np.linspace(-1, 1, 1000)[:, None]
768+
model.fit([X, X], [X, X])
769+
return model
770+
771+
return [do_nothing, fit_singlesession_model, fit_multisession_model]
767772

768773

769774
def _assert_same_state_dict(first, second):
@@ -797,17 +802,43 @@ def check_fitted(model):
797802
_assert_same_state_dict(original_model.state_dict_,
798803
loaded_model.state_dict_)
799804
X = np.random.normal(0, 1, (100, 1))
800-
assert np.allclose(loaded_model.transform(X),
801-
original_model.transform(X))
805+
806+
if loaded_model.num_sessions is not None:
807+
assert np.allclose(loaded_model.transform(X, session_id=0),
808+
original_model.transform(X, session_id=0))
809+
else:
810+
assert np.allclose(loaded_model.transform(X),
811+
original_model.transform(X))
802812

803813

804814
@pytest.mark.parametrize("action", _iterate_actions())
805815
def test_save_and_load(action):
806816
model_architecture = "offset10-model"
807817
original_model = cebra_sklearn_cebra.CEBRA(
808-
model_architecture=model_architecture, max_iterations=5)
818+
model_architecture=model_architecture, max_iterations=5, batch_size=42)
809819
original_model = action(original_model)
810820
with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
811821
original_model.save(savefile.name)
812822
loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name)
813823
_assert_equal(original_model, loaded_model)
824+
825+
826+
@pytest.mark.parametrize("device", ["cpu"] +
827+
["cuda"] if torch.cuda.is_available() else [])
828+
@pytest.mark.parametrize("action", _iterate_actions())
829+
def test_check_devices(action, device):
830+
cebra_model = cebra_sklearn_cebra.CEBRA(
831+
model_architecture="offset1-model",
832+
max_iterations=5,
833+
device=device,
834+
batch_size=42,
835+
)
836+
cebra_model = action(cebra_model)
837+
assert cebra_model.device == device
838+
839+
if action.__name__ != "do_nothing":
840+
if device == "cuda":
841+
#TODO(rodrigo): remove once https://github.com/AdaptiveMotorControlLab/CEBRA/pull/34 is merged.
842+
device = torch.device(device, index=0)
843+
assert next(
844+
cebra_model.model_.parameters()).device == torch.device(device)

0 commit comments

Comments
 (0)