Commit c5dc011

Implement reviews on tests and typing
1 parent 0d56e44 commit c5dc011

File tree

11 files changed: +191 -147 lines changed

cebra/data/base.py

Lines changed: 1 addition & 1 deletion
@@ -197,7 +197,7 @@ def configure_for(self, model: "cebra.models.Model"):
         """Configure the dataset offset for the provided model.
 
         Call this function before indexing the dataset. This sets the
-        ``offset`` attribute of the dataset.
+        :py:attr:`offset` attribute of the dataset.
 
         Args:
             model: The model to configure the dataset for.

cebra/data/datasets.py

Lines changed: 1 addition & 1 deletion
@@ -353,7 +353,7 @@ def configure_for(self, model: "Model"):
         """Configure the dataset offset for the provided model.
 
         Call this function before indexing the dataset. This sets the
-        ``offset`` attribute of the dataset.
+        :py:attr:`offset` attribute of the dataset.
 
         Args:
             model: The model to configure the dataset for.

cebra/data/multi_session.py

Lines changed: 6 additions & 4 deletions
@@ -31,6 +31,7 @@
 import cebra.distributions
 from cebra.data.datatypes import Batch
 from cebra.data.datatypes import BatchIndex
+from cebra.models import Model
 
 __all__ = [
     "MultiSessionDataset",
@@ -104,17 +105,18 @@ def load_batch(self, index: BatchIndex) -> List[Batch]:
         ) for session_id, session in enumerate(self.iter_sessions())
         ]
 
-    def configure_for(self, model: "cebra.models.Model"):
+    def configure_for(self, model: "Model"):
         """Configure the dataset offset for the provided model.
 
         Call this function before indexing the dataset. This sets the
-        ``offset`` attribute of the dataset.
+        :py:attr:`cebra.data.Dataset.offset` attribute of the dataset.
 
         Args:
             model: The model to configure the dataset for.
         """
-        for i, session in enumerate(self.iter_sessions()):
-            session.configure_for(model[i])
+        self.offset = model.get_offset()
+        for session in self.iter_sessions():
+            session.configure_for(model)
 
 
 @dataclasses.dataclass
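
The hunk above changes the multi-session API: configure_for now receives one shared model instead of a per-session list, and copies the model's offset onto the collection itself. A minimal usage sketch of the new behaviour, assuming two synthetic sessions and the "offset10-model" architecture (data shapes and model name are illustrative, not part of this commit):

import numpy as np

import cebra.data
import cebra.models

# Two toy sessions with identical shapes (an assumption for this sketch).
neural = [np.random.uniform(0, 1, (1000, 30)).astype("float32") for _ in range(2)]
labels = [np.random.uniform(0, 1, (1000, 1)).astype("float32") for _ in range(2)]
dataset = cebra.data.DatasetCollection(
    *[cebra.data.TensorDataset(x, continuous=y) for x, y in zip(neural, labels)])

model = cebra.models.init("offset10-model",
                          num_neurons=30, num_units=32, num_output=8)

# After this commit: pass a single model, not a list indexed per session;
# the collection's own offset is taken from model.get_offset().
dataset.configure_for(model)
print(dataset.offset)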

cebra/integrations/sklearn/cebra.py

Lines changed: 1 addition & 63 deletions
@@ -22,7 +22,6 @@
 """Define the CEBRA model."""
 
 import itertools
-import warnings
 from typing import (Callable, Dict, Iterable, List, Literal, Optional, Tuple,
                     Union)
 
@@ -687,7 +686,7 @@ def _get_dataset_multi(X: List[Iterable], y: List[Iterable]):
         if not _are_sessions_equal(X, y):
             raise ValueError(
                 "Invalid number of samples or labels sessions: provide one session for single-session training, "
-                "and make sure the number of samples in X and y need match, "
+                "and make sure the number of samples in X and y match, "
                 f"got {len(X)} and {[len(y_i) for y_i in y]}.")
         is_multisession = False
         dataset = _get_dataset(X, y)
@@ -1255,67 +1254,6 @@ def transform(self,
 
         return output.detach().cpu().numpy()
 
-    #NOTE: Deprecated: transform is now handled in the solver but the original
-    # method is kept here for testing.
-    def transform_deprecated(self,
-                             X: Union[npt.NDArray, torch.Tensor],
-                             session_id: Optional[int] = None) -> npt.NDArray:
-        """Transform an input sequence and return the embedding.
-
-        Args:
-            X: A numpy array or torch tensor of size ``time x dimension``.
-            session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for
-                multisession, set to ``None`` for single session.
-
-        Returns:
-            A :py:func:`numpy.array` of size ``time x output_dimension``.
-
-        Example:
-
-            >>> import cebra
-            >>> import numpy as np
-            >>> dataset = np.random.uniform(0, 1, (1000, 30))
-            >>> cebra_model = cebra.CEBRA(max_iterations=10)
-            >>> cebra_model.fit(dataset)
-            CEBRA(max_iterations=10)
-            >>> embedding = cebra_model.transform(dataset)
-
-        """
-        warnings.warn(
-            "The method `transform_deprecated` is deprecated "
-            "but kept for testing puroposes."
-            "We recommend using `transform` instead.",
-            DeprecationWarning,
-            stacklevel=2)
-
-        sklearn_utils_validation.check_is_fitted(self, "n_features_")
-        model, offset = self._select_model(X, session_id)
-
-        # Input validation
-        X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
-        input_dtype = X.dtype
-
-        with torch.no_grad():
-            model.eval()
-
-            if self.pad_before_transform:
-                X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)),
-                           mode="edge")
-            X = torch.from_numpy(X).float().to(self.device_)
-
-            if isinstance(model, cebra.models.ConvolutionalModelMixin):
-                # Fully convolutional evaluation, switch (T, C) -> (1, C, T)
-                X = X.transpose(1, 0).unsqueeze(0)
-                output = model(X).cpu().numpy().squeeze(0).transpose(1, 0)
-            else:
-                # Standard evaluation, (T, C, dt)
-                output = model(X).cpu().numpy()
-
-        if input_dtype == "float64":
-            return output.astype(input_dtype)
-
-        return output
-
     def fit_transform(
         self,
         X: Union[npt.NDArray, torch.Tensor],
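
With transform_deprecated removed from the estimator, the supported path is the regular transform call, matching the doctest kept in the deprecated helper. A short sketch of that recommended usage on random data:

import numpy as np

import cebra

X = np.random.uniform(0, 1, (1000, 30))
cebra_model = cebra.CEBRA(max_iterations=10)
cebra_model.fit(X)
# Embedding of shape (time, output_dimension).
embedding = cebra_model.transform(X)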

cebra/solver/base.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@
 
 def _check_indices(batch_start_idx: int, batch_end_idx: int,
                    offset: cebra.data.Offset, num_samples: int):
-    """Check that indexes in a batch are in a correct range.
+    """Check that indices in a batch are in a correct range.
 
     First and last index must be positive integers, smaller than
     the total length of inputs in the dataset, the first index

cebra/solver/multiobjective.py

Lines changed: 0 additions & 50 deletions
@@ -456,56 +456,6 @@ def validation(
         self.log.setdefault(("sum_loss_val",), []).append(sum_loss_valid)
         return stats_val
 
-    # NOTE: Deprecated: batched transform can now be performed (more memory efficient)
-    # using the transform method of the model, and handling padding is implemented
-    # directly in the base Solver. This method is kept for testing purposes.
-    @torch.no_grad()
-    def transform_deprecated(self, inputs: torch.Tensor) -> torch.Tensor:
-        """Transform the input data using the model.
-
-        Args:
-            inputs: The input data to transform.
-
-        Returns:
-            The transformed data.
-        """
-
-        warnings.warn(
-            "The method `transform_deprecated` is deprecated "
-            "but kept for testing puroposes."
-            "We recommend using `transform` instead.",
-            DeprecationWarning,
-            stacklevel=2)
-
-        offset = self.model.get_offset()
-        self.model.eval()
-        X = inputs.cpu().numpy()
-        X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)), mode="edge")
-        X = torch.from_numpy(X).float().to(self.device)
-
-        if isinstance(self.model.module, cebra.models.ConvolutionalModelMixin):
-            # Fully convolutional evaluation, switch (T, C) -> (1, C, T)
-            X = X.transpose(1, 0).unsqueeze(0)
-            outputs = self.model(X)
-
-            # switch back from (1, C, T) -> (T, C)
-            if isinstance(outputs, torch.Tensor):
-                assert outputs.dim() == 3 and outputs.shape[0] == 1
-                outputs = outputs.squeeze(0).transpose(1, 0)
-            elif isinstance(outputs, tuple):
-                assert all(tensor.dim() == 3 and tensor.shape[0] == 1
-                           for tensor in outputs)
-                outputs = (
-                    output.squeeze(0).transpose(1, 0) for output in outputs)
-                outputs = tuple(outputs)
-            else:
-                raise ValueError("Invalid condition in solver.transform")
-        else:
-            # Standard evaluation, (T, C, dt)
-            outputs = self.model(X)
-
-        return outputs
-
 
 @register("supervised-solver-xcebra")
 @dataclasses.dataclass

cebra/solver/single_session.py

Lines changed: 1 addition & 4 deletions
@@ -285,10 +285,7 @@ def _select_model(
         self._check_is_inputs_valid(inputs, session_id=session_id)
 
         model = self.model.module
-        if hasattr(model, 'get_offset'):
-            offset = model.get_offset()
-        else:
-            offset = None
+        offset = model.get_offset()
         return model, offset
 
 
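The simplified _select_model now assumes every wrapped model implements get_offset() rather than falling back to None. A small sketch of that assumption; the model name is illustrative, and the printed values are what an offset-10 architecture is expected to report:

import cebra.models

model = cebra.models.init("offset10-model",
                          num_neurons=30, num_units=32, num_output=8)
offset = model.get_offset()
# Left/right receptive-field extent; len(offset) is used elsewhere in this
# commit as the minimum number of input samples.
print(offset.left, offset.right, len(offset))  # e.g. 5 5 10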

tests/_utils_deprecated.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+import warnings
+from typing import Optional, Union
+
+import numpy as np
+import numpy.typing as npt
+import sklearn.utils.validation as sklearn_utils_validation
+import torch
+
+import cebra
+import cebra.integrations.sklearn.utils as sklearn_utils
+import cebra.models
+import cebra.solvers
+
+
+#NOTE: Deprecated: transform is now handled in the solver but the original
+# method is kept here for testing.
+def cebra_transform_deprecated(cebra_model,
+                               X: Union[npt.NDArray, torch.Tensor],
+                               session_id: Optional[int] = None) -> npt.NDArray:
+    """Transform an input sequence and return the embedding.
+
+    Args:
+        cebra_model: The CEBRA model to use for the transform.
+        X: A numpy array or torch tensor of size ``time x dimension``.
+        session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for
+            multisession, set to ``None`` for single session.
+
+    Returns:
+        A :py:func:`numpy.array` of size ``time x output_dimension``.
+
+    Example:
+
+        >>> import cebra
+        >>> import numpy as np
+        >>> dataset = np.random.uniform(0, 1, (1000, 30))
+        >>> cebra_model = cebra.CEBRA(max_iterations=10)
+        >>> cebra_model.fit(dataset)
+        CEBRA(max_iterations=10)
+        >>> embedding = cebra_model.transform(dataset)
+
+    """
+    warnings.warn(
+        "The method is deprecated "
+        "but kept for testing puroposes."
+        "We recommend using `transform` instead.",
+        DeprecationWarning,
+        stacklevel=2)
+
+    sklearn_utils_validation.check_is_fitted(cebra_model, "n_features_")
+    model, offset = cebra_model._select_model(X, session_id)
+
+    # Input validation
+    X = sklearn_utils.check_input_array(X, min_samples=len(cebra_model.offset_))
+    input_dtype = X.dtype
+
+    with torch.no_grad():
+        model.eval()
+
+        if cebra_model.pad_before_transform:
+            X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)),
+                       mode="edge")
+        X = torch.from_numpy(X).float().to(cebra_model.device_)
+
+        if isinstance(model, cebra.models.ConvolutionalModelMixin):
+            # Fully convolutional evaluation, switch (T, C) -> (1, C, T)
+            X = X.transpose(1, 0).unsqueeze(0)
+            output = model(X).cpu().numpy().squeeze(0).transpose(1, 0)
+        else:
+            # Standard evaluation, (T, C, dt)
+            output = model(X).cpu().numpy()
+
+    if input_dtype == "float64":
+        return output.astype(input_dtype)
+
+    return output
+
+
+# NOTE: Deprecated: batched transform can now be performed (more memory efficient)
+# using the transform method of the model, and handling padding is implemented
+# directly in the base Solver. This method is kept for testing purposes.
+@torch.no_grad()
+def multiobjective_transform_deprecated(solver: cebra.solvers.Solver,
+                                        inputs: torch.Tensor) -> torch.Tensor:
+    """Transform the input data using the model.
+
+    Args:
+        solver: The solver containing the model and device.
+        inputs: The input data to transform.
+
+    Returns:
+        The transformed data.
+    """
+
+    warnings.warn(
+        "The method is deprecated "
+        "but kept for testing puroposes."
+        "We recommend using `transform` instead.",
+        DeprecationWarning,
+        stacklevel=2)
+
+    offset = solver.model.get_offset()
+    solver.model.eval()
+    X = inputs.cpu().numpy()
+    X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)), mode="edge")
+    X = torch.from_numpy(X).float().to(solver.device)
+
+    if isinstance(solver.model.module, cebra.models.ConvolutionalModelMixin):
+        # Fully convolutional evaluation, switch (T, C) -> (1, C, T)
+        X = X.transpose(1, 0).unsqueeze(0)
+        outputs = solver.model(X)
+
+        # switch back from (1, C, T) -> (T, C)
+        if isinstance(outputs, torch.Tensor):
+            assert outputs.dim() == 3 and outputs.shape[0] == 1
+            outputs = outputs.squeeze(0).transpose(1, 0)
+        elif isinstance(outputs, tuple):
+            assert all(tensor.dim() == 3 and tensor.shape[0] == 1
+                       for tensor in outputs)
+            outputs = (output.squeeze(0).transpose(1, 0) for output in outputs)
+            outputs = tuple(outputs)
+        else:
+            raise ValueError("Invalid condition in solver.transform")
+    else:
+        # Standard evaluation, (T, C, dt)
+        outputs = solver.model(X)
+
+    return outputs
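
These helpers are meant to be imported from tests, as the change to tests/test_integration_xcebra.py below shows for the multi-objective case. A hypothetical pytest-style sketch of the analogous check for the sklearn path, with tolerances mirroring the existing test (the test name and data sizes are assumptions):

import numpy as np

import cebra

import _utils_deprecated


def test_transform_matches_deprecated():
    X = np.random.uniform(0, 1, (500, 10)).astype("float32")
    model = cebra.CEBRA(max_iterations=10, batch_size=128)
    model.fit(X)
    new_embedding = model.transform(X)
    old_embedding = _utils_deprecated.cebra_transform_deprecated(model, X)
    assert np.allclose(new_embedding, old_embedding, rtol=1e-4, atol=1e-4)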

tests/test_integration_xcebra.py

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,6 @@
 import pickle
 
+import _utils_deprecated
 import numpy as np
 import pytest
 import torch
@@ -173,8 +174,8 @@ def test_synthetic_data_training(synthetic_data, device):
                         atol=1e-4)
 
     # Test and compare the previous transform (transform_deprecated)
-    deprecated_transform_embedding = solver.transform_deprecated(
-        data.neural.to(device))
+    deprecated_transform_embedding = _utils_deprecated.multiobjective_transform_deprecated(
+        solver, data.neural.to(device))
     assert np.allclose(embedding,
                        deprecated_transform_embedding,
                        rtol=1e-4,
