
Commit d71ca8d

Improve modularity, remove duplicate code and TODOs
1 parent b73c123 commit d71ca8d

File tree

3 files changed: +12 −63 lines changed


cebra/integrations/sklearn/cebra.py

Lines changed: 1 addition & 56 deletions
@@ -1235,7 +1235,7 @@ def transform(self,
         sklearn_utils_validation.check_is_fitted(self, "n_features_")
         self.solver_._check_is_session_id_valid(session_id=session_id)
 
-        if torch.is_tensor(X):
+        if torch.is_tensor(X) and X.device.type == "cuda":
             X = X.detach().cpu()
 
         X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
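The tightened guard only moves CUDA tensors back to the host; CPU tensors and NumPy arrays now pass through unchanged instead of being detached and copied unconditionally. A minimal standalone sketch of the same check (the helper name is hypothetical, for illustration only):

```python
import torch

def to_host_if_cuda(X):
    # Detach and copy to CPU only when X is a CUDA tensor;
    # CPU tensors and non-tensor inputs are returned as-is.
    if torch.is_tensor(X) and X.device.type == "cuda":
        X = X.detach().cpu()
    return X
```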
@@ -1256,61 +1256,6 @@ def transform(self,
 
         return output.detach().cpu().numpy()
 
-    # Deprecated, kept for testing.
-    def transform_deprecated(self,
-                             X: Union[npt.NDArray, torch.Tensor],
-                             session_id: Optional[int] = None) -> npt.NDArray:
-        """Transform an input sequence and return the embedding.
-
-        Args:
-            X: A numpy array or torch tensor of size ``time x dimension``.
-            session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for
-                multisession, set to ``None`` for single session.
-
-        Returns:
-            A :py:func:`numpy.array` of size ``time x output_dimension``.
-
-        Example:
-
-            >>> import cebra
-            >>> import numpy as np
-            >>> dataset = np.random.uniform(0, 1, (1000, 30))
-            >>> cebra_model = cebra.CEBRA(max_iterations=10)
-            >>> cebra_model.fit(dataset)
-            CEBRA(max_iterations=10)
-            >>> embedding = cebra_model.transform(dataset)
-
-        """
-
-        sklearn_utils_validation.check_is_fitted(self, "n_features_")
-        model, offset = self._select_model(X, session_id)
-
-        # Input validation
-        X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
-        #input_dtype = X.dtype
-
-        with torch.no_grad():
-            model.eval()
-
-            if self.pad_before_transform:
-                X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)),
-                           mode="edge")
-            X = torch.from_numpy(X).float().to(self.device_)
-
-            if isinstance(model, cebra.models.ConvolutionalModelMixin):
-                # Fully convolutional evaluation, switch (T, C) -> (1, C, T)
-                X = X.transpose(1, 0).unsqueeze(0)
-                output = model(X).cpu().numpy().squeeze(0).transpose(1, 0)
-            else:
-                # Standard evaluation, (T, C, dt)
-                output = model(X).cpu().numpy()
-
-        #TODO: check if this is safe.
-        return output.numpy(force=True)
-
-        #if input_dtype == "float64":
-        #    return output.astype(input_dtype)
-
     def fit_transform(
         self,
         X: Union[npt.NDArray, torch.Tensor],
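With the deprecated path gone, `transform` is the single supported entry point, and its usage is unchanged. The doctest from the removed method still illustrates the call pattern:

```python
>>> import cebra
>>> import numpy as np
>>> dataset = np.random.uniform(0, 1, (1000, 30))
>>> cebra_model = cebra.CEBRA(max_iterations=10)
>>> cebra_model.fit(dataset)
CEBRA(max_iterations=10)
>>> embedding = cebra_model.transform(dataset)
```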

cebra/solver/base.py

Lines changed: 8 additions & 4 deletions
@@ -451,6 +451,8 @@ def fit(
             if logdir is not None:
                 self.save(logdir, f"checkpoint_{num_steps:#07d}.pth")
 
+        self._set_fitted_params(loader)
+
     def step(self, batch: cebra.data.Batch) -> dict:
         """Perform a single gradient update.
 
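Calling a finalization hook at the end of `fit` lets the solver record its fitted state in one place rather than duplicating that bookkeeping across `transform` and `load`. The body of `_set_fitted_params` is not part of this diff; a hypothetical sketch of what such a hook could record, assuming the loader's dataset exposes its input dimensionality:

```python
# Hypothetical sketch only; the real _set_fitted_params is not shown in this commit.
def _set_fitted_params(self, loader):
    # Assumption: the loader's dataset reports the dimensionality of its inputs,
    # so transform() can later validate data without the original loader.
    self.n_features = loader.dataset.input_dimension
```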
@@ -610,10 +612,6 @@ def transform(self,
 
         if len(offset) < 2 and pad_before_transform:
             pad_before_transform = False
-
-        if batch_size is not None and batch_size < 1:
-            raise ValueError(
-                f"Batch size should be at least 1, got {batch_size}")
 
         model.eval()
         if batch_size is not None:
@@ -665,6 +663,12 @@ def load(self, logdir, filename="checkpoint.pth"):
         checkpoint = torch.load(savepath, map_location=self.device)
         self.load_state_dict(checkpoint, strict=True)
 
+        if hasattr(self.model, "n_features"):
+            n_features = self.model.n_features
+            self.n_features = ([
+                session_n_features for session_n_features in n_features
+            ] if isinstance(n_features, list) else n_features)
+
     def save(self, logdir, filename="checkpoint_last.pth"):
         """Save the model and optimizer params.
 
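Restoring `n_features` from the loaded model keeps a reloaded solver usable for inference without refitting. A hedged round-trip sketch, assuming an already fitted solver instance `solver` and the `save`/`load` signatures shown above:

```python
solver.save("logs")                         # writes logs/checkpoint_last.pth
solver.load("logs", "checkpoint_last.pth")  # restores weights and n_features
# After load(), solver.n_features mirrors the model's n_features
# (copied per session when it is a list), so transform() can validate
# inputs against the fitted dimensionality.
```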

cebra/solver/multi_session.py

Lines changed: 3 additions & 3 deletions
@@ -41,9 +41,9 @@ class MultiSessionSolver(abc_.Solver):
 
     def parameters(self, session_id: Optional[int] = None):
         """Iterate over all parameters."""
-        if session_id is not None:
-            for parameter in self.model[session_id].parameters():
-                yield parameter
+        self._check_is_session_id_valid(session_id=session_id)
+        for parameter in self.model[session_id].parameters():
+            yield parameter
 
         for parameter in self.criterion.parameters():
             yield parameter
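This changes the method's contract: the session ID is now validated up front instead of being silently skipped, so an invalid or missing session no longer falls through to yielding only the criterion's parameters. A hedged usage sketch, assuming a fitted `MultiSessionSolver` named `solver`:

```python
# Parameters of session 0's model, followed by the criterion's parameters.
params = list(solver.parameters(session_id=0))

# Because parameters() is a generator, the validity check runs on first
# iteration; an invalid session_id (e.g. None) is then expected to raise
# inside _check_is_session_id_valid rather than being ignored.
list(solver.parameters(session_id=None))  # presumably raises
```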
