Skip to content

Commit 7b0cc68

Browse files
committed
Add review updates
1 parent dafabe5 commit 7b0cc68

File tree

5 files changed

+72
-15
lines changed

5 files changed

+72
-15
lines changed

cebra/integrations/sklearn/cebra.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1235,7 +1235,7 @@ def transform(self,
12351235
sklearn_utils_validation.check_is_fitted(self, "n_features_")
12361236
self.solver_._check_is_session_id_valid(session_id=session_id)
12371237

1238-
if torch.is_tensor(X) and X.device.type == "cuda":
1238+
if torch.is_tensor(X):
12391239
X = X.detach().cpu()
12401240

12411241
X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
@@ -1256,6 +1256,60 @@ def transform(self,
12561256

12571257
return output.detach().cpu().numpy()
12581258

1259+
# Deprecated, kept for testing.
1260+
def transform_deprecated(self,
1261+
X: Union[npt.NDArray, torch.Tensor],
1262+
session_id: Optional[int] = None) -> npt.NDArray:
1263+
"""Transform an input sequence and return the embedding.
1264+
1265+
Args:
1266+
X: A numpy array or torch tensor of size ``time x dimension``.
1267+
session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for
1268+
multisession, set to ``None`` for single session.
1269+
1270+
Returns:
1271+
A :py:func:`numpy.array` of size ``time x output_dimension``.
1272+
1273+
Example:
1274+
1275+
>>> import cebra
1276+
>>> import numpy as np
1277+
>>> dataset = np.random.uniform(0, 1, (1000, 30))
1278+
>>> cebra_model = cebra.CEBRA(max_iterations=10)
1279+
>>> cebra_model.fit(dataset)
1280+
CEBRA(max_iterations=10)
1281+
>>> embedding = cebra_model.transform(dataset)
1282+
1283+
"""
1284+
1285+
sklearn_utils_validation.check_is_fitted(self, "n_features_")
1286+
model, offset = self._select_model(X, session_id)
1287+
1288+
# Input validation
1289+
X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
1290+
input_dtype = X.dtype
1291+
1292+
with torch.no_grad():
1293+
model.eval()
1294+
1295+
if self.pad_before_transform:
1296+
X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)),
1297+
mode="edge")
1298+
X = torch.from_numpy(X).float().to(self.device_)
1299+
1300+
if isinstance(model, cebra.models.ConvolutionalModelMixin):
1301+
# Fully convolutional evaluation, switch (T, C) -> (1, C, T)
1302+
X = X.transpose(1, 0).unsqueeze(0)
1303+
output = model(X).cpu().numpy().squeeze(0).transpose(1, 0)
1304+
else:
1305+
# Standard evaluation, (T, C, dt)
1306+
output = model(X).cpu().numpy()
1307+
1308+
if input_dtype == "float64":
1309+
return output.astype(input_dtype)
1310+
1311+
return output
1312+
12591313
def fit_transform(
12601314
self,
12611315
X: Union[npt.NDArray, torch.Tensor],

cebra/solver/base.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,6 @@ def fit(
452452
if logdir is not None:
453453
self.save(logdir, f"checkpoint_{num_steps:#07d}.pth")
454454

455-
self._set_fitted_params(loader)
456-
457455
def step(self, batch: cebra.data.Batch) -> dict:
458456
"""Perform a single gradient update.
459457
@@ -603,9 +601,18 @@ def transform(self,
603601
Returns:
604602
The output embedding.
605603
"""
604+
if not self.is_fitted:
605+
raise ValueError(
606+
f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with "
607+
"appropriate arguments before using this estimator.")
608+
609+
if batch_size is not None and batch_size < 1:
610+
raise ValueError(
611+
f"Batch size should be at least 1, got {batch_size}")
612+
606613
if isinstance(inputs, list):
607-
raise NotImplementedError(
608-
"Inputs to transform() should be the data for a single session."
614+
raise ValueError(
615+
"Inputs to transform() should be the data for a single session, but received a list."
609616
)
610617
elif not isinstance(inputs, torch.Tensor):
611618
raise ValueError(
@@ -673,7 +680,7 @@ def load(self, logdir, filename="checkpoint.pth"):
673680
session_n_features for session_n_features in n_features
674681
] if isinstance(n_features, list) else n_features)
675682

676-
def save(self, logdir, filename="checkpoint.pth"):
683+
def save(self, logdir, filename="checkpoint_last.pth"):
677684
"""Save the model and optimizer params.
678685
679686
Args:

cebra/solver/multi_session.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ class MultiSessionSolver(abc_.Solver):
4141

4242
def parameters(self, session_id: Optional[int] = None):
4343
"""Iterate over all parameters."""
44-
self._check_is_session_id_valid(session_id=session_id)
45-
for parameter in self.model[session_id].parameters():
46-
yield parameter
44+
if session_id is not None:
45+
for parameter in self.model[session_id].parameters():
46+
yield parameter
4747

4848
for parameter in self.criterion.parameters():
4949
yield parameter

tests/test_sklearn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1519,4 +1519,4 @@ def test_last_incomplete_batch_smaller_than_offset():
15191519
device="cpu")
15201520
model.fit(train.neural, train.continuous)
15211521

1522-
_ = model.transform(train.neural, batch_size=300)
1522+
_ = model.transform(train.neural, batch_size=300)

tests/test_solver.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def test_multi_session(data_name, loader_initfunc, model_architecture,
374374

375375
with pytest.raises(RuntimeError, match="No.*session_id"):
376376
embedding = solver.transform(X[0])
377-
with pytest.raises(RuntimeError, match="single.*session"):
377+
with pytest.raises(ValueError, match="single.*session"):
378378
embedding = solver.transform(X)
379379
with pytest.raises(RuntimeError, match="Invalid.*session_id"):
380380
embedding = solver.transform(X[0], session_id=5)
@@ -384,10 +384,6 @@ def test_multi_session(data_name, loader_initfunc, model_architecture,
384384
for param in solver.parameters(session_id=0):
385385
assert isinstance(param, torch.Tensor)
386386

387-
with pytest.raises(RuntimeError, match="No.*session_id"):
388-
for param in solver.parameters():
389-
assert isinstance(param, torch.Tensor)
390-
391387
fitted_solver = copy.deepcopy(solver)
392388
with tempfile.TemporaryDirectory() as temp_dir:
393389
solver.save(temp_dir)

0 commit comments

Comments
 (0)