Commit 65fc455

Add some docstrings and typings and clean unnecessary changes
1 parent 9875a38 commit 65fc455

File tree

5 files changed: +116 additions, -29 deletions

cebra/data/single_session.py

Lines changed: 0 additions & 5 deletions
@@ -371,11 +371,6 @@ def __post_init__(self):
         self._init_behavior_distribution()
         self._init_time_distribution()
 
-        if self.conditional != "time_delta":
-            raise NotImplementedError(
-                "Hybrid training is currently only implemented using the ``time_delta`` "
-                "continual distribution.")
-
     def _init_behavior_distribution(self):
         if self.conditional == "time":
             self.behavior_distribution = cebra.distributions.TimeContrastive(
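With the guard above removed, hybrid training is no longer rejected outright for conditional distributions other than ``time_delta``. A minimal sketch of what this change permits through the sklearn API (the ``conditional`` and ``hybrid`` parameters exist in CEBRA; whether every ``conditional`` value is fully supported in hybrid mode is an assumption here):

```python
import cebra

# Hybrid training with a non-"time_delta" conditional distribution:
# previously this configuration raised NotImplementedError in __post_init__.
model = cebra.CEBRA(conditional="time", hybrid=True, max_iterations=10)
```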

cebra/integrations/sklearn/cebra.py

Lines changed: 2 additions & 2 deletions
@@ -1227,7 +1227,7 @@ def transform(self,
         >>> cebra_model = cebra.CEBRA(max_iterations=10)
         >>> cebra_model.fit(dataset)
         CEBRA(max_iterations=10)
-        >>> embedding = cebra_model.transform(dataset)
+        >>> embedding = cebra_model.transform(dataset, batch_size=200)
 
         """
         sklearn_utils_validation.check_is_fitted(self, "n_features_")

@@ -1254,7 +1254,7 @@ def transform(self,
 
         return output.detach().cpu().numpy()
 
-    #NOTE: Deprecated, as transform is now handled in the solver but kept for testing.
+    #NOTE: Deprecated: transform is now handled in the solver but kept for testing.
     def transform_deprecated(self,
                              X: Union[npt.NDArray, torch.Tensor],
                              session_id: Optional[int] = None) -> npt.NDArray:
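The updated doctest advertises the ``batch_size`` argument of ``transform``. A short, self-contained sketch of that usage (the dataset shape is illustrative):

```python
import numpy as np
import cebra

# Illustrative dataset: 10,000 timesteps of 30-dimensional neural data.
dataset = np.random.uniform(0, 1, (10000, 30))

cebra_model = cebra.CEBRA(max_iterations=10)
cebra_model.fit(dataset)

# Transforming in chunks of 200 samples bounds peak memory on long
# recordings; the concatenated embedding should match a full-batch transform.
embedding = cebra_model.transform(dataset, batch_size=200)
```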

cebra/solver/base.py

Lines changed: 32 additions & 15 deletions
@@ -54,9 +54,10 @@ def _check_indices(batch_start_idx: int, batch_end_idx: int,
                    offset: cebra.data.Offset, num_samples: int):
     """Check that indexes in a batch are in a correct range.
 
-    First and last index must be positive integers, smaller than the total length of inputs
-    in the dataset, the first index must be smaller than the last and the batch size cannot
-    be smaller than the offset of the model.
+    First and last index must be positive integers, smaller than
+    the total length of inputs in the dataset, the first index
+    must be smaller than the last and the batch size cannot be
+    smaller than the offset of the model.
 
     Args:
         batch_start_idx: Index of the first sample in the batch.
@@ -380,6 +381,16 @@ def num_parameters(self) -> int:
 
     @abc.abstractmethod
     def parameters(self, session_id: Optional[int] = None):
+        """Iterate over all parameters of the model.
+
+        Args:
+            session_id: The session ID, an :py:class:`int` between 0 and
+                the number of sessions -1 for multisession, and set to
+                ``None`` for single session.
+
+        Yields:
+            The parameters of the model.
+        """
         raise NotImplementedError
 
     def _get_loader(self, loader):
@@ -573,6 +584,13 @@ def _select_model(
         raise NotImplementedError
 
     def _check_is_fitted(self):
+        """Check if the model is fitted.
+
+        If the model is fitted, the solver should have a `n_features` attribute.
+
+        Raises:
+            ValueError: If the model is not fitted.
+        """
         if not hasattr(self, "n_features"):
             raise ValueError(
                 f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with "
@@ -581,7 +599,7 @@ def _check_is_fitted(self):
     @torch.no_grad()
     def transform(self,
                   inputs: Union[torch.Tensor, List[torch.Tensor], npt.NDArray],
-                  pad_before_transform: bool = True,
+                  pad_before_transform: Optional[bool] = True,
                   session_id: Optional[int] = None,
                   batch_size: Optional[int] = None) -> torch.Tensor:
         """Compute the embedding.
@@ -591,11 +609,12 @@ def transform(self,
 
         Args:
             inputs: The input signal
-            pad_before_transform: If ``False``, no padding is applied to the input sequence.
-                and the output sequence will be smaller than the input sequence due to the
-                receptive field of the model. If the input sequence is ``n`` steps long,
-                and a model with receptive field ``m`` is used, the output sequence would
-                only be ``n-m+1`` steps long.
+            pad_before_transform: If ``False``, no padding is applied to the input
+                sequence and the output sequence will be smaller than the input
+                sequence due to the receptive field of the model. If the
+                input sequence is ``n`` steps long, and a model with receptive
+                field ``m`` is used, the output sequence would only be
+                ``n-m+1`` steps long.
             session_id: The session ID, an :py:class:`int` between 0 and
                 the number of sessions -1 for multisession, and set to
                 ``None`` for single session.
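The ``n-m+1`` relationship in the revised docstring is the standard "valid" convolution length. A quick numpy analogy for the arithmetic (not CEBRA's implementation, just an illustration of the shape rule):

```python
import numpy as np

n, m = 1000, 10                    # input length, model receptive field
x = np.random.randn(n)
kernel = np.ones(m) / m            # stand-in for a receptive-field-m model

# Without padding ("valid" mode), each output step consumes m input steps:
out = np.convolve(x, kernel, mode="valid")
assert out.shape[0] == n - m + 1   # 991 output steps
```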
@@ -640,8 +659,6 @@ def transform(self,
     def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch:
         """Given a batch of input examples, return the model outputs.
 
-        TODO: make this a public function?
-
         Args:
             batch: The input data, not necessarily aligned across the batch
                 dimension. This means that ``batch.index`` specifies the map
@@ -654,12 +671,12 @@ def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch:
         """
         raise NotImplementedError
 
-    def load(self, logdir, filename="checkpoint.pth"):
+    def load(self, logdir: str, filename: str = "checkpoint.pth"):
         """Load the experiment from its checkpoint file.
 
         Args:
-            logdir: Log directory.
-            filename (str): Checkpoint name for loading the experiment.
+            logdir: Logging directory.
+            filename: Checkpoint name for loading the experiment.
         """
 
         savepath = os.path.join(logdir, filename)

@@ -674,7 +691,7 @@ def load(self, logdir, filename="checkpoint.pth"):
             session_n_features for session_n_features in n_features
         ] if isinstance(n_features, list) else n_features)
 
-    def save(self, logdir, filename="checkpoint_last.pth"):
+    def save(self, logdir: str, filename: str = "checkpoint_last.pth"):
         """Save the model and optimizer params.
 
         Args:
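A hedged sketch of the checkpointing round-trip these signatures describe (``solver`` stands for any fitted solver instance, which is assumed here; the filenames are the defaults from the diff):

```python
import tempfile

logdir = tempfile.mkdtemp()

# Persist the model and optimizer params, then restore them later.
solver.save(logdir, filename="checkpoint_last.pth")
solver.load(logdir, filename="checkpoint_last.pth")
```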

cebra/solver/multi_session.py

Lines changed: 27 additions & 4 deletions
@@ -40,11 +40,21 @@ class MultiSessionSolver(abc_.Solver):
     _variant_name = "multi-session"
 
     def parameters(self, session_id: Optional[int] = None):
-        """Iterate over all parameters."""
+        """Iterate over all parameters.
+
+        Args:
+            session_id: The session ID, an :py:class:`int` between 0 and
+                the number of sessions -1 for multisession, and set to
+                ``None`` for single session.
+
+        Yields:
+            The parameters of the model.
+        """
         if session_id is not None:
             for parameter in self.model[session_id].parameters():
                 yield parameter
 
+        # If session_id is None, it can still iterate over the criterion
         for parameter in self.criterion.parameters():
             yield parameter
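The yield order above (the chosen session's model first, then the criterion) is easy to mirror in a standalone generator; a small torch-only sketch of the same pattern (the names and shapes here are illustrative, not CEBRA's):

```python
import torch

models = [torch.nn.Linear(10, 3), torch.nn.Linear(12, 3)]  # one model per session
criterion = torch.nn.Linear(3, 1)  # stand-in for a learnable criterion

def parameters(session_id=None):
    # Yield the selected session's model parameters, then always the
    # criterion parameters (also when session_id is None).
    if session_id is not None:
        yield from models[session_id].parameters()
    yield from criterion.parameters()

n_params = sum(p.numel() for p in parameters(session_id=0))
```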

@@ -161,12 +171,12 @@ def _check_is_inputs_valid(self, inputs: torch.Tensor,
     def _check_is_session_id_valid(self, session_id: Optional[int]):
         """Check that the session ID provided is valid for the solver instance.
 
-        The session ID must be non-null and between 0 and the number session in the dataset.
+        The session ID must be non-null and between 0 and the number session
+        in the dataset.
 
         Args:
             session_id: The session ID to check.
         """
-
         if session_id is None:
             raise RuntimeError(
                 "No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape."
@@ -233,7 +243,20 @@ class MultiSessionAuxVariableSolver(MultiSessionSolver):
     _variant_name = "multi-session-aux"
     reference_model: torch.nn.Module
 
-    def _inference(self, batches):
+    def _inference(self, batches: List[cebra.data.Batch]) -> cebra.data.Batch:
+        """Given batches of input examples, computes the feature representations/embeddings.
+
+        Args:
+            batches: A list of input data, not necessarily aligned across the batch
+                dimension. This means that ``batch.index`` specifies the map
+                between reference/positive samples, if not equal ``None``.
+
+        Returns:
+            Processed batch of data. While the input data might not be aligned
+            across the sample dimensions, the output data should be aligned and
+            ``batch.index`` should be set to ``None``.
+
+        """
         refs = []
         poss = []
         negs = []

cebra/solver/single_session.py

Lines changed: 55 additions & 3 deletions
@@ -46,8 +46,18 @@ class SingleSessionSolver(abc_.Solver):
     _variant_name = "single-session"
 
     def parameters(self, session_id: Optional[int] = None):
-        """Iterate over all parameters."""
-        self._check_is_session_id_valid(session_id=session_id)
+        """Iterate over all parameters.
+
+        Args:
+            session_id: The session ID, an :py:class:`int` between 0 and
+                the number of sessions -1 for multisession, and set to
+                ``None`` for single session.
+
+        Yields:
+            The parameters of the model.
+        """
+        # If session_id is invalid, it doesn't matter, since we are
+        # using a single session solver.
         for parameter in self.model.parameters():
             yield parameter

@@ -196,7 +206,22 @@ def __post_init__(self):
         self.reference_model = copy.deepcopy(self.model)
         self.reference_model.to(self.model.device)
 
-    def _inference(self, batch):
+    def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch:
+        """Given a batch of input examples, computes the feature representation/embedding.
+
+        The reference samples are processed with a different model than the
+        positive and negative samples.
+
+        Args:
+            batch: The input data, not necessarily aligned across the batch
+                dimension. This means that ``batch.index`` specifies the map
+                between reference/positive samples, if not equal ``None``.
+
+        Returns:
+            Processed batch of data. While the input data might not be aligned
+            across the sample dimensions, the output data should be aligned and
+            ``batch.index`` should be set to ``None``.
+        """
         batch.to(self.device)
         ref = self.reference_model(batch.reference)
         pos = self.model(batch.positive)
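The deepcopy-based reference model this docstring describes can be sketched in a few lines of plain torch (illustrative shapes; not CEBRA's actual model classes):

```python
import copy
import torch

model = torch.nn.Linear(20, 8)
reference_model = copy.deepcopy(model)  # separate copy for reference samples

reference = torch.randn(32, 20)
positive = torch.randn(32, 20)

ref = reference_model(reference)  # reference goes through the copied model
pos = model(positive)             # positive (and negative) samples go
                                  # through the trained model
```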
@@ -212,6 +237,21 @@ class SingleSessionHybridSolver(abc_.MultiobjectiveSolver, SingleSessionSolver):
     _variant_name = "single-session-hybrid"
 
     def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch:
+        """Given a batch of input examples, computes the feature representation/embedding.
+
+        The samples are processed with both a time-contrastive module and a
+        behavior-contrastive module, that are part of the same model.
+
+        Args:
+            batch: The input data, not necessarily aligned across the batch
+                dimension. This means that ``batch.index`` specifies the map
+                between reference/positive samples, if not equal ``None``.
+
+        Returns:
+            Processed batch of data. While the input data might not be aligned
+            across the sample dimensions, the output data should be aligned and
+            ``batch.index`` should be set to ``None``.
+        """
         batch.to(self.device)
         behavior_ref = self.model(batch.reference)[0]
         behavior_pos = self.model(batch.positive[:int(len(batch.positive) //
@@ -305,6 +345,18 @@ def get_embedding(self, data):
         return self.model(data[0].T)
 
     def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch:
+        """Given a batch of input examples, computes the feature representation/embedding.
+
+        Args:
+            batch: The input data, not necessarily aligned across the batch
+                dimension. This means that ``batch.index`` specifies the map
+                between reference/positive samples, if not equal ``None``.
+
+        Returns:
+            Processed batch of data. While the input data might not be aligned
+            across the sample dimensions, the output data should be aligned and
+            ``batch.index`` should be set to ``None``.
+        """
         outputs = self.get_embedding(self.neural)
         idc = batch.positive - self.offset.left >= len(outputs)
         batch.positive[idc] = batch.reference[idc]
