Commit 1d0c498

Implement review comments

1 parent 65fc455 commit 1d0c498

10 files changed: +240 −243 lines

cebra/data/base.py

Lines changed: 2 additions & 3 deletions

@@ -193,17 +193,16 @@ def load_batch(self, index: BatchIndex) -> Batch:
         """
         raise NotImplementedError()

-    @abc.abstractmethod
     def configure_for(self, model: "cebra.models.Model"):
         """Configure the dataset offset for the provided model.

         Call this function before indexing the dataset. This sets the
-        :py:attr:`offset` attribute of the dataset.
+        ``offset`` attribute of the dataset.

         Args:
             model: The model to configure the dataset for.
         """
-        raise NotImplementedError
+        self.offset = model.get_offset()


 @dataclasses.dataclass
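
With this change, `configure_for` has a concrete default implementation on the base dataset class instead of being abstract, so subclasses (see cebra/data/single_session.py below) no longer need their own copy. Below is a minimal usage sketch, not taken from the diff; the toy data, model name, and sizes are placeholders chosen for illustration:

    import numpy as np
    import cebra

    # Hypothetical toy data: 1000 samples, 30 neurons, 2 continuous labels.
    neural = np.random.uniform(size=(1000, 30)).astype("float32")
    labels = np.random.uniform(size=(1000, 2)).astype("float32")

    dataset = cebra.data.TensorDataset(neural, continuous=labels)
    model = cebra.models.init("offset10-model",
                              num_neurons=30,
                              num_units=32,
                              num_output=8)

    # After this commit the base class implements this call as
    # `self.offset = model.get_offset()` instead of raising NotImplementedError.
    dataset.configure_for(model)
    print(dataset.offset)  # the model's receptive-field offset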

cebra/data/datasets.py

Lines changed: 1 addition & 1 deletion

@@ -353,7 +353,7 @@ def configure_for(self, model: "Model"):
         """Configure the dataset offset for the provided model.

         Call this function before indexing the dataset. This sets the
-        :py:attr:`offset` attribute of the dataset.
+        ``offset`` attribute of the dataset.

         Args:
             model: The model to configure the dataset for.

cebra/data/multi_session.py

Lines changed: 1 addition & 1 deletion

@@ -108,7 +108,7 @@ def configure_for(self, model: "cebra.models.Model"):
         """Configure the dataset offset for the provided model.

         Call this function before indexing the dataset. This sets the
-        :py:attr:`cebra_data.Dataset.offset` attribute of the dataset.
+        ``offset`` attribute of the dataset.

         Args:
             model: The model to configure the dataset for.

cebra/data/single_session.py

Lines changed: 0 additions & 11 deletions

@@ -69,17 +69,6 @@ def load_batch(self, index: BatchIndex) -> Batch:
             reference=self[index.reference],
         )

-    def configure_for(self, model: "cebra.models.Model"):
-        """Configure the dataset offset for the provided model.
-
-        Call this function before indexing the dataset. This sets the
-        :py:attr:`cebra_data.Dataset.offset` attribute of the dataset.
-
-        Args:
-            model: The model to configure the dataset for.
-        """
-        self.offset = model.get_offset()
-

 @dataclasses.dataclass
 class DiscreteDataLoader(cebra_data.Loader):

cebra/integrations/sklearn/cebra.py

Lines changed: 19 additions & 11 deletions

@@ -22,6 +22,7 @@
 """Define the CEBRA model."""

 import itertools
+import warnings
 from typing import (Callable, Dict, Iterable, List, Literal, Optional, Tuple,
                     Union)

@@ -129,7 +130,7 @@ def _init_loader(
         (not is_cont, not is_disc, is_multi),
     ]
     if any(all(combination) for combination in incompatible_combinations):
-        raise ValueError(f"Invalid index combination.\n"
+        raise ValueError("Invalid index combination.\n"
                          f"Continuous: {is_cont},\n"
                          f"Discrete: {is_disc},\n"
                          f"Hybrid training: {is_hybrid},\n"
@@ -293,7 +294,7 @@ def _require_arg(key):
         "single-session",
     )

-    error_message = (f"Invalid index combination.\n"
+    error_message = ("Invalid index combination.\n"
                      f"Continuous: {is_cont},\n"
                      f"Discrete: {is_disc},\n"
                      f"Hybrid training: {is_hybrid},\n"
@@ -340,7 +341,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
    if missing_keys:
        raise ValueError(
            f"Missing keys in data dictionary: {', '.join(missing_keys)}. "
-            f"You can try loading the CEBRA model with the torch backend.")
+            "You can try loading the CEBRA model with the torch backend.")

    args, state, state_dict = cebra_info['args'], cebra_info[
        'state'], cebra_info['state_dict']
@@ -656,12 +657,12 @@ def _get_dataset_multi(X: List[Iterable], y: List[Iterable]):
        # TODO(celia): to make it work for multiple set of index. For now, y should be a tuple of one list only
        if isinstance(y, tuple) and len(y) > 1:
            raise NotImplementedError(
-                f"Support for multiple set of index is not implemented in multissesion training, "
+                "Support for multiple set of index is not implemented in multissesion training, "
                f"got {len(y)} sets of indexes.")

        if not _are_sessions_equal(X, y):
            raise ValueError(
-                f"Invalid number of sessions: number of sessions in X and y need to match, "
+                "Invalid number of sessions: number of sessions in X and y need to match, "
                f"got X:{len(X)} and y:{[len(y_i) for y_i in y]}.")

        for session in range(len(X)):
@@ -685,8 +686,8 @@ def _get_dataset_multi(X: List[Iterable], y: List[Iterable]):
    else:
        if not _are_sessions_equal(X, y):
            raise ValueError(
-                f"Invalid number of samples or labels sessions: provide one session for single-session training, "
-                f"and make sure the number of samples in X and y need match, "
+                "Invalid number of samples or labels sessions: provide one session for single-session training, "
+                "and make sure the number of samples in X and y need match, "
                f"got {len(X)} and {[len(y_i) for y_i in y]}.")
        is_multisession = False
        dataset = _get_dataset(X, y)
@@ -848,7 +849,7 @@ def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):
        # Check that same number of index
        if len(self.label_types_) != n_idx:
            raise ValueError(
-                f"Number of index invalid: labels must have the same number of index as for fitting,"
+                "Number of index invalid: labels must have the same number of index as for fitting,"
                f"expects {len(self.label_types_)}, got {n_idx} idx.")

        for i in range(len(self.label_types_)):  # for each index
@@ -861,12 +862,12 @@ def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):
                    > 1):  # is there more than one feature in the index
                if label_types_idx[1][1] != y[i].shape[1]:
                    raise ValueError(
-                        f"Labels invalid: must have the same number of features as the ones used for fitting,"
+                        "Labels invalid: must have the same number of features as the ones used for fitting,"
                        f"expects {label_types_idx[1]}, got {y[i].shape}.")

                if label_types_idx[0] != y[i].dtype:
                    raise ValueError(
-                        f"Labels invalid: must have the same type of features as the ones used for fitting,"
+                        "Labels invalid: must have the same type of features as the ones used for fitting,"
                        f"expects {label_types_idx[0]}, got {y[i].dtype}.")

    def _prepare_fit(
@@ -1254,7 +1255,8 @@ def transform(self,

        return output.detach().cpu().numpy()

-    #NOTE: Deprecated: transform is now handled in the solver but kept for testing.
+    #NOTE: Deprecated: transform is now handled in the solver but the original
+    # method is kept here for testing.
    def transform_deprecated(self,
                             X: Union[npt.NDArray, torch.Tensor],
                             session_id: Optional[int] = None) -> npt.NDArray:
@@ -1279,6 +1281,12 @@ def transform_deprecated(self,
        >>> embedding = cebra_model.transform(dataset)

        """
+        warnings.warn(
+            "The method `transform_deprecated` is deprecated "
+            "but kept for testing puroposes."
+            "We recommend using `transform` instead.",
+            DeprecationWarning,
+            stacklevel=2)

        sklearn_utils_validation.check_is_fitted(self, "n_features_")
        model, offset = self._select_model(X, session_id)
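
Because `transform_deprecated` now emits a `DeprecationWarning` with `stacklevel=2`, the warning is attributed to the caller's call site. A short sketch of how a test could assert this behaviour, using pytest, which the repository's test suite already relies on; `fitted_cebra_model` and `X` below are hypothetical placeholders, not objects defined in this commit:

    import pytest

    def test_transform_deprecated_warns(fitted_cebra_model, X):
        # `fitted_cebra_model` is assumed to be a fitted CEBRA estimator and
        # `X` matching input data; the warning message contains the method
        # name, so `match` finds it.
        with pytest.warns(DeprecationWarning, match="transform_deprecated"):
            fitted_cebra_model.transform_deprecated(X)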

cebra/solver/base.py

Lines changed: 16 additions & 9 deletions

@@ -84,7 +84,7 @@ def _check_indices(batch_start_idx: int, batch_end_idx: int,
        raise ValueError(
            f"The batch has length {batch_size_length} which "
            f"is smaller or equal than the required offset length {len(offset)}."
-            f"Either choose a model with smaller offset or the batch should contain more samples."
+            f"Either choose a model with smaller offset or the batch should contain 3 times more samples."
        )


@@ -127,7 +127,7 @@ def _get_batch(inputs: torch.Tensor, offset: Optional[cebra.data.Offset],
        inputs: Input data.
        offset: Model offset.
        batch_start_idx: Index of the first sample in the batch.
-        batch_end_idx: Index of the first sample in the batch.
+        batch_end_idx: Index of the last sample in the batch.
        pad_before_transform: If True zero-pad the batched data.

    Returns:
@@ -237,8 +237,8 @@ def __getitem__(self, idx):

    if len(index_dataloader) < 2:
        raise ValueError(
-            f"Number of batches must be greater than 1, you can use transform without batching instead, got {len(index_dataloader)}."
-        )
+            f"Number of batches must be greater than 1, you can use transform "
+            f"without batching instead, got {len(index_dataloader)}.")

    output = []
    for batch_idx, index_batch in enumerate(index_dataloader):
@@ -253,7 +253,11 @@ def __getitem__(self, idx):
        if batch_idx == (len(index_dataloader) - 1):
            # last batch, incomplete
            index_batch = torch.cat((last_batch, index_batch), dim=0)
+            assert index_batch[-1] + 1 == len(inputs), (
+                f"Last batch index {index_batch[-1]} + 1 should be equal to the length of inputs {len(inputs)}."
+            )

+        # Batch start and end so that `batch_size` size with the last batch including 2 batches
        batch_start_idx, batch_end_idx = index_batch[0], index_batch[-1] + 1
        batched_data = _get_batch(inputs=inputs,
                                  offset=offset,
@@ -264,7 +268,7 @@ def __getitem__(self, idx):
        output_batch = _inference_transform(model, batched_data)
        output.append(output_batch)

-    output = torch.cat(output)
+    output = torch.cat(output, dim=0)
    return output


@@ -608,7 +612,7 @@ def transform(self,
        of the given model, after switching it into eval mode.

        Args:
-            inputs: The input signal
+            inputs: The input signal (T, N).
            pad_before_transform: If ``False``, no padding is applied to the input
                sequence and the output sequence will be smaller than the input
                sequence due to the receptive field of the model. If the
@@ -635,11 +639,14 @@ def transform(self,

        model, offset = self._select_model(inputs, session_id)

-        if len(offset) < 2 and pad_before_transform:
-            pad_before_transform = False
+        #if len(offset) < 2 and pad_before_transform:
+        #    pad_before_transform = False

        model.eval()
-        if batch_size is not None and inputs.shape[0] > int(batch_size * 2):
+        if batch_size is not None and inputs.shape[0] > int(
+                batch_size * 2) and not isinstance(
+                    self.model, cebra.models.ResampleModelMixin):
+            # NOTE: resampling models are not supported for batched inference.
            output = _batched_transform(
                model=model,
                inputs=inputs,
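
The batched path splits the input into `batch_size`-sized index chunks and merges the final, incomplete chunk into the preceding one, so the last processed batch can span close to two chunk lengths (hence the added assert that the merged batch ends exactly at `len(inputs)`). The standalone sketch below illustrates that index arithmetic; it is an approximation of the idea, not the solver's actual code:

    def batch_bounds(n_samples: int, batch_size: int):
        """Return (start, end) pairs covering n_samples; an incomplete final
        chunk is folded into the previous one instead of standing alone."""
        starts = list(range(0, n_samples, batch_size))
        ends = [min(s + batch_size, n_samples) for s in starts]
        if len(starts) > 1 and ends[-1] - starts[-1] < batch_size:
            starts.pop()      # drop the incomplete chunk's start ...
            ends.pop(-2)      # ... and let the previous chunk run to the end
        return list(zip(starts, ends))

    print(batch_bounds(n_samples=1050, batch_size=512))
    # [(0, 512), (512, 1050)]: the last batch absorbs the 26 leftover samples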

cebra/solver/multiobjective.py

Lines changed: 20 additions & 1 deletion

@@ -155,7 +155,7 @@ def finalize(self):
        if len(set(self.feature_ranges_tuple)) != len(
                self.feature_ranges_tuple):
            raise RuntimeError(
-                f"Feature ranges are not unique. Please check again and remove the duplicates. "
+                "Feature ranges are not unique. Please check again and remove the duplicates. "
                f"Feature ranges: {self.feature_ranges_tuple}")

        print("Creating MultiCriterion")
@@ -456,8 +456,27 @@ def validation(
        self.log.setdefault(("sum_loss_val",), []).append(sum_loss_valid)
        return stats_val

+    # NOTE: Deprecated: batched transform can now be performed (more memory efficient)
+    # using the transform method of the model, and handling padding is implemented
+    # directly in the base Solver. This method is kept for testing purposes.
    @torch.no_grad()
    def transform_deprecated(self, inputs: torch.Tensor) -> torch.Tensor:
+        """Transform the input data using the model.
+
+        Args:
+            inputs: The input data to transform.
+
+        Returns:
+            The transformed data.
+        """
+
+        warnings.warn(
+            "The method `transform_deprecated` is deprecated "
+            "but kept for testing puroposes."
+            "We recommend using `transform` instead.",
+            DeprecationWarning,
+            stacklevel=2)
+
        offset = self.model.get_offset()
        self.model.eval()
        X = inputs.cpu().numpy()

tests/test_integration_xcebra.py

Lines changed: 23 additions & 4 deletions

@@ -158,13 +158,32 @@ def test_synthetic_data_training(synthetic_data, device):
    assert transform_embedding.shape[
        1] == n_latents, "Incorrect embedding dimension"
    assert not torch.isnan(transform_embedding).any(), "NaN values in embedding"
-    assert np.allclose(embedding, transform_embedding, rtol=1e-02)
+    assert np.allclose(embedding, transform_embedding, rtol=1e-4, atol=1e-4)

    # Test the transform with batching
    batched_embedding = solver.transform(data.neural.to(device), batch_size=512)
    assert batched_embedding.shape[
        1] == n_latents, "Incorrect embedding dimension"
    assert not torch.isnan(batched_embedding).any(), "NaN values in embedding"
-    assert np.allclose(embedding, batched_embedding, rtol=1e-02)
-
-    assert np.allclose(transform_embedding, batched_embedding, rtol=1e-02)
+    assert np.allclose(embedding, batched_embedding, rtol=1e-4, atol=1e-4)
+
+    assert np.allclose(transform_embedding,
+                       batched_embedding,
+                       rtol=1e-4,
+                       atol=1e-4)
+
+    # Test and compare the previous transform (transform_deprecated)
+    deprecated_transform_embedding = solver.transform_deprecated(
+        data.neural.to(device))
+    assert np.allclose(embedding,
+                       deprecated_transform_embedding,
+                       rtol=1e-4,
+                       atol=1e-4)
+    assert np.allclose(transform_embedding,
+                       deprecated_transform_embedding,
+                       rtol=1e-4,
+                       atol=1e-4)
+    assert np.allclose(batched_embedding,
+                       deprecated_transform_embedding,
+                       rtol=1e-4,
+                       atol=1e-4)
