
Commit c2544c7

Add review updates

1 parent: 7f58607

6 files changed: +300, -28 lines

cebra/data/base.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -207,7 +207,6 @@ def configure_for(self, model: "cebra.models.Model"):
             model: The model to configure the dataset for.
         """
         raise NotImplementedError
-        self.offset = model.get_offset()


 @dataclasses.dataclass
```
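The deleted assignment was dead code: it sat after an unconditional `raise`, so it could never execute. A minimal standalone illustration (hypothetical function, not CEBRA code):

```python
def configure_for(model):
    raise NotImplementedError
    # Unreachable: Python raises on the line above before reaching this one.
    offset = model.get_offset()
```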

cebra/integrations/sklearn/cebra.py

Lines changed: 59 additions & 1 deletion
```diff
@@ -1202,14 +1202,18 @@ def transform(self,
         sklearn_utils_validation.check_is_fitted(self, "n_features_")
         self.solver_._check_is_session_id_valid(session_id=session_id)

-        if torch.is_tensor(X) and X.device.type == "cuda":
+        if torch.is_tensor(X):
             X = X.detach().cpu()

         X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))

         if isinstance(X, np.ndarray):
             X = torch.from_numpy(X)

+        if batch_size is not None and batch_size < 1:
+            raise ValueError(
+                f"Batch size should be at least 1, got {batch_size}")
+
         with torch.no_grad():
             output = self.solver_.transform(
                 inputs=X,
```
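A quick usage sketch of the new behavior, assuming `transform()` exposes the `batch_size` argument that the added validation guards (data shapes mirror the docstring example below):

```python
import numpy as np
import cebra

X = np.random.uniform(0, 1, (1000, 30))
cebra_model = cebra.CEBRA(max_iterations=10)
cebra_model.fit(X)

# Any torch tensor input (CPU or CUDA) is now detached and moved to CPU first;
# previously only CUDA tensors were.
embedding = cebra_model.transform(X, batch_size=100)  # batched inference
# cebra_model.transform(X, batch_size=0)  # ValueError: Batch size should be at least 1, got 0
```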
```diff
@@ -1219,6 +1223,60 @@ def transform(self,

         return output.detach().cpu().numpy()

+    # Deprecated, kept for testing.
+    def transform_deprecated(self,
+                             X: Union[npt.NDArray, torch.Tensor],
+                             session_id: Optional[int] = None) -> npt.NDArray:
+        """Transform an input sequence and return the embedding.
+
+        Args:
+            X: A numpy array or torch tensor of size ``time x dimension``.
+            session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for
+                multisession, set to ``None`` for single session.
+
+        Returns:
+            A :py:func:`numpy.array` of size ``time x output_dimension``.
+
+        Example:
+
+            >>> import cebra
+            >>> import numpy as np
+            >>> dataset = np.random.uniform(0, 1, (1000, 30))
+            >>> cebra_model = cebra.CEBRA(max_iterations=10)
+            >>> cebra_model.fit(dataset)
+            CEBRA(max_iterations=10)
+            >>> embedding = cebra_model.transform(dataset)
+
+        """
+
+        sklearn_utils_validation.check_is_fitted(self, "n_features_")
+        model, offset = self._select_model(X, session_id)
+
+        # Input validation
+        X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
+        input_dtype = X.dtype
+
+        with torch.no_grad():
+            model.eval()
+
+            if self.pad_before_transform:
+                X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)),
+                           mode="edge")
+            X = torch.from_numpy(X).float().to(self.device_)
+
+            if isinstance(model, cebra.models.ConvolutionalModelMixin):
+                # Fully convolutional evaluation, switch (T, C) -> (1, C, T)
+                X = X.transpose(1, 0).unsqueeze(0)
+                output = model(X).cpu().numpy().squeeze(0).transpose(1, 0)
+            else:
+                # Standard evaluation, (T, C, dt)
+                output = model(X).cpu().numpy()
+
+            if input_dtype == "float64":
+                return output.astype(input_dtype)
+
+        return output
+
     def fit_transform(
         self,
         X: Union[npt.NDArray, torch.Tensor],
```
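The `np.pad` call in this deprecated path encodes the receptive-field arithmetic: a model with receptive field `len(offset) = offset.left + offset.right` maps `N` input samples to `N - len(offset) + 1` outputs, so edge-padding by `(offset.left, offset.right - 1)` keeps the embedding the same length as the input. A standalone check with illustrative numbers:

```python
import numpy as np

T, C = 100, 30
left, right = 4, 5  # illustrative stand-in for cebra.data.Offset(4, 5)
X_padded = np.pad(np.random.rand(T, C), ((left, right - 1), (0, 0)), mode="edge")
n_outputs = X_padded.shape[0] - (left + right) + 1  # convolutional output length
assert n_outputs == T
```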

cebra/solver/base.py

Lines changed: 21 additions & 14 deletions
```diff
@@ -81,18 +81,17 @@ def _check_indices(batch_start_idx: int, batch_end_idx: int,
             f"batch_end_idx ({batch_end_idx}) cannot exceed the length of inputs ({num_samples})."
         )

-    batch_size_lenght = batch_end_idx - batch_start_idx
-    if batch_size_lenght <= len(offset):
+    batch_size_length = batch_end_idx - batch_start_idx
+    if batch_size_length <= len(offset):
         raise ValueError(
-            f"The batch has length {batch_size_lenght} which "
+            f"The batch has length {batch_size_length} which "
             f"is smaller or equal than the required offset length {len(offset)}."
             f"Either choose a model with smaller offset or the batch should contain more samples."
         )


 def _add_batched_zero_padding(batched_data: torch.Tensor,
-                              offset: cebra.data.Offset,
-                              batch_start_idx: int,
+                              offset: cebra.data.Offset, batch_start_idx: int,
                               batch_end_idx: int,
                               num_samples: int) -> torch.Tensor:
     """Add zero padding to the input data before inference.
```
```diff
@@ -409,6 +408,7 @@ def fit(
         TODO:
             * Refine the API here. Drop the validation entirely, and implement this via a hook?
         """
+        self._set_fitted_params(loader)
         self.to(loader.device)

         iterator = self._get_loader(loader)
@@ -436,8 +436,6 @@ def fit(
                     save_hook(num_steps, self)
                 self.save(logdir, f"checkpoint_{num_steps:#07d}.pth")

-        self._set_fitted_params(loader)
-
     def step(self, batch: cebra.data.Batch) -> dict:
         """Perform a single gradient update.

```
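Moving `_set_fitted_params(loader)` from the end of `fit()` to the top means attributes like `n_features` exist before the training loop starts, so code invoked mid-training (for example a `save_hook`) already sees a fitted solver. A hypothetical skeleton of the resulting control flow, not CEBRA code:

```python
class SketchSolver:
    def _set_fitted_params(self, loader):
        self.n_features = 30  # illustrative stand-in

    def fit(self, loader, save_hook=None):
        self._set_fitted_params(loader)  # moved: runs before the loop now
        for num_steps, batch in enumerate(loader):
            if save_hook is not None:
                save_hook(num_steps, self)  # hooks see a fitted solver

SketchSolver().fit(loader=[0, 1], save_hook=lambda i, s: print(i, s.n_features))
```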
```diff
@@ -553,6 +551,10 @@ def _select_model(
         """
         raise NotImplementedError

+    @property
+    def is_fitted(self):
+        return hasattr(self, "n_features")
+
     @torch.no_grad()
     def transform(self,
                   inputs: Union[torch.Tensor, List[torch.Tensor], npt.NDArray],
```
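The new `is_fitted` property simply reports whether `n_features` has been assigned, which is what `_set_fitted_params` (now called at the start of `fit()`) is responsible for. A self-contained demonstration of the semantics:

```python
class Demo:
    @property
    def is_fitted(self):
        return hasattr(self, "n_features")

d = Demo()
assert not d.is_fitted
d.n_features = 30  # what _set_fitted_params effectively establishes
assert d.is_fitted
```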
```diff
@@ -579,19 +581,24 @@ def transform(self,
         Returns:
             The output embedding.
         """
+        if not self.is_fitted:
+            raise ValueError(
+                f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with "
+                "appropriate arguments before using this estimator.")
+
+        if batch_size is not None and batch_size < 1:
+            raise ValueError(
+                f"Batch size should be at least 1, got {batch_size}")
+
         if isinstance(inputs, list):
-            raise NotImplementedError(
-                "Inputs to transform() should be the data for a single session."
+            raise ValueError(
+                "Inputs to transform() should be the data for a single session, but received a list."
             )

         elif not isinstance(inputs, torch.Tensor):
             raise ValueError(
                 f"Inputs should be a torch.Tensor, not {type(inputs)}.")

-        if not hasattr(self, "n_features"):
-            raise ValueError(
-                f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with "
-                "appropriate arguments before using this estimator.")
         model, offset = self._select_model(inputs, session_id)

         if len(offset) < 2 and pad_before_transform:
```
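With the reordering, `transform()` fails fast before `_select_model` runs, and list inputs now raise `ValueError` instead of `NotImplementedError`. A hedged sketch of what a caller sees, assuming `solver` is a fitted single-session solver and `X` a valid tensor:

```python
try:
    solver.transform(inputs=[X, X])  # multi-session style input
except ValueError as err:
    print(err)  # "... single session, but received a list."

try:
    solver.transform(inputs=X, batch_size=0)
except ValueError as err:
    print(err)  # "Batch size should be at least 1, got 0"
```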
```diff
@@ -647,7 +654,7 @@ def load(self, logdir, filename="checkpoint.pth"):
         checkpoint = torch.load(savepath, map_location=self.device)
         self.load_state_dict(checkpoint, strict=True)

-    def save(self, logdir, filename="checkpoint.pth"):
+    def save(self, logdir, filename="checkpoint_last.pth"):
         """Save the model and optimizer params.

         Args:
```
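One side effect worth flagging: `load()` (visible in the hunk header) still defaults to `"checkpoint.pth"`, while `save()` now defaults to `"checkpoint_last.pth"`, so a default save/load round trip needs the filename spelled out. Sketch, with `solver` any solver instance:

```python
solver.save("logs")  # writes logs/checkpoint_last.pth
solver.load("logs", filename="checkpoint_last.pth")  # load()'s default would look for checkpoint.pth
```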

cebra/solver/multi_session.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -44,9 +44,9 @@ class MultiSessionSolver(abc_.Solver):

     def parameters(self, session_id: Optional[int] = None):
         """Iterate over all parameters."""
-        self._check_is_session_id_valid(session_id=session_id)
-        for parameter in self.model[session_id].parameters():
-            yield parameter
+        if session_id is not None:
+            for parameter in self.model[session_id].parameters():
+                yield parameter

         for parameter in self.criterion.parameters():
             yield parameter
```
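After this change `parameters()` no longer validates `session_id`: with `session_id=None` only the criterion's parameters are yielded, and a session's model parameters are included only when that session is named explicitly. Usage sketch, assuming a fitted `MultiSessionSolver` named `solver`:

```python
criterion_only = list(solver.parameters())          # session_id=None
with_model = list(solver.parameters(session_id=0))  # model[0] + criterion
assert len(with_model) >= len(criterion_only)
```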
