
Commit 83c1669

CeliaBenquet authored and stes committed
Add tests to solver
1 parent b417a23 commit 83c1669

File tree

7 files changed: +458 -266 lines


cebra/data/base.py

Lines changed: 4 additions & 0 deletions
@@ -196,6 +196,7 @@ def load_batch(self, index: BatchIndex) -> Batch:
         """
         raise NotImplementedError()
 
+    @abc.abstractmethod
     def configure_for(self, model: "cebra.models.Model"):
         """Configure the dataset offset for the provided model.
 
@@ -205,6 +206,7 @@ def configure_for(self, model: "cebra.models.Model"):
         Args:
            model: The model to configure the dataset for.
        """
+        raise NotImplementedError
        self.offset = model.get_offset()
 
 
@@ -230,6 +232,8 @@ class Loader(abc.ABC, cebra.io.HasDevice):
        doc="""A dataset instance specifying a ``__getitem__`` function.""",
    )
 
+    time_offset: int = dataclasses.field(default=10)
+
    num_steps: int = dataclasses.field(
        default=None,
        doc=

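For orientation, here is a minimal, standalone sketch of the contract this change formalizes: `configure_for` is now an abstract method that every dataset has to implement, and the shared `time_offset` field is defined once on the base `Loader`. The `Toy*` classes below are hypothetical stand-ins, not cebra classes.

# Hypothetical stand-ins illustrating the configure_for contract; not cebra classes.
import abc
import dataclasses


@dataclasses.dataclass
class ToyOffset:
    left: int
    right: int

    def __len__(self) -> int:
        return self.left + self.right


class ToyModel:

    def get_offset(self) -> ToyOffset:
        # A model with a receptive field of 10 samples.
        return ToyOffset(left=5, right=5)


class ToyDatasetBase(abc.ABC):

    @abc.abstractmethod
    def configure_for(self, model: ToyModel) -> None:
        """Subclasses must set ``self.offset`` for the given model."""
        raise NotImplementedError


class ToyDataset(ToyDatasetBase):

    def configure_for(self, model: ToyModel) -> None:
        self.offset = model.get_offset()


@dataclasses.dataclass
class ToyLoader:
    # Shared across all loaders, mirroring the field moved to the base class.
    time_offset: int = 10


dataset = ToyDataset()
dataset.configure_for(ToyModel())
print(len(dataset.offset))  # -> 10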
cebra/data/multi_session.py

Lines changed: 12 additions & 3 deletions
@@ -111,6 +111,18 @@ def configure_for(self, model):
         for session in self.iter_sessions():
             session.configure_for(model)
 
+    def configure_for(self, model: "cebra.models.Model"):
+        """Configure the dataset offset for the provided model.
+
+        Call this function before indexing the dataset. This sets the
+        :py:attr:`offset` attribute of the dataset.
+
+        Args:
+            model: The model to configure the dataset for.
+        """
+        for i, session in enumerate(self.iter_sessions()):
+            session.configure_for(model[i])
+
 
 @dataclasses.dataclass
 class MultiSessionLoader(cebra_data.Loader):
@@ -121,8 +133,6 @@ class MultiSessionLoader(cebra_data.Loader):
     dimension, it is better to use a :py:class:`cebra.data.single_session.MixedDataLoader`.
     """
 
-    time_offset: int = dataclasses.field(default=10)
-
     def __post_init__(self):
         super().__post_init__()
         self.sampler = cebra_distr.MultisessionSampler(self.dataset,
@@ -151,7 +161,6 @@ class ContinuousMultiSessionDataLoader(MultiSessionLoader):
     """Contrastive learning conditioned on a continuous behavior variable."""
 
     conditional: str = "time_delta"
-    time_offset: int = dataclasses.field(default=10)
 
     @property
     def index(self):

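The collection-level `configure_for` added above expects one model per session and forwards the matching entry by index. A short continuation of the stand-ins from the previous sketch (the `ToyDatasetCollection` is hypothetical, not the cebra class):

# Continues the Toy* stand-ins from the sketch above; hypothetical, not cebra code.
class ToyDatasetCollection:

    def __init__(self, *sessions: ToyDataset):
        self._sessions = list(sessions)

    def iter_sessions(self):
        return iter(self._sessions)

    def configure_for(self, model) -> None:
        # ``model`` holds one model per session, e.g. a list or nn.ModuleList.
        for i, session in enumerate(self.iter_sessions()):
            session.configure_for(model[i])


collection = ToyDatasetCollection(ToyDataset(), ToyDataset())
collection.configure_for([ToyModel(), ToyModel()])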
cebra/data/single_session.py

Lines changed: 11 additions & 3 deletions
@@ -72,6 +72,17 @@ def load_batch(self, index: BatchIndex) -> Batch:
             reference=self[index.reference],
         )
 
+    def configure_for(self, model: "cebra.models.Model"):
+        """Configure the dataset offset for the provided model.
+
+        Call this function before indexing the dataset. This sets the
+        :py:attr:`offset` attribute of the dataset.
+
+        Args:
+            model: The model to configure the dataset for.
+        """
+        self.offset = model.get_offset()
+
 
 @dataclasses.dataclass
 class DiscreteDataLoader(cebra_data.Loader):
@@ -192,7 +203,6 @@ class ContinuousDataLoader(cebra_data.Loader):
             and become equivalent to time contrastive learning.
             """,
     )
-    time_offset: int = dataclasses.field(default=10)
     delta: float = dataclasses.field(default=0.1)
 
     def __post_init__(self):
@@ -274,7 +284,6 @@ class MixedDataLoader(cebra_data.Loader):
     """
 
     conditional: str = dataclasses.field(default="time_delta")
-    time_offset: int = dataclasses.field(default=10)
 
     @property
     def dindex(self):
@@ -337,7 +346,6 @@ class HybridDataLoader(cebra_data.Loader):
     """
 
     conditional: str = dataclasses.field(default="time_delta")
-    time_offset: int = dataclasses.field(default=10)
     delta: float = dataclasses.field(default=0.1)
 
     @property

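For a single-session dataset, the configured offset describes how many samples to the left and right of an index belong to one model input. A rough, standalone illustration of what an offset of (left=5, right=5) means when slicing; this is not cebra's actual indexing code:

# Rough illustration only; not cebra's actual __getitem__ logic.
import torch

neural = torch.randn(1000, 30)   # (time, channels)
left, right = 5, 5               # offset set by configure_for
t = 100
window = neural[t - left:t + right]
print(window.shape)              # torch.Size([10, 30])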
cebra/integrations/sklearn/cebra.py

Lines changed: 1 addition & 3 deletions
@@ -776,8 +776,6 @@ def _configure_for_all(
                             f"receptive fields/offsets larger than 1 via the sklearn API. "
                             f"Please use a different model, or revert to the pytorch "
                             f"API for training.")
-
-                d.configure_for(model[n])
         else:
             if not isinstance(model, cebra.models.ConvolutionalModelMixin):
                 if len(model.get_offset()) > 1:
@@ -787,7 +785,7 @@
                         f"Please use a different model, or revert to the pytorch "
                         f"API for training.")
 
-            dataset.configure_for(model)
+        dataset.configure_for(model)
 
     def _select_model(self, X: Union[npt.NDArray, torch.Tensor],
                       session_id: int):

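With the per-session dispatch living on the dataset collection, `_configure_for_all` keeps the receptive-field check per branch but ends with a single `dataset.configure_for(model)` call. A condensed, hypothetical sketch of that flow (`check_offset`, `configure_for_all` and their arguments are illustrative, not cebra's API):

# Condensed, hypothetical sketch of the resulting control flow; not cebra's code.
def check_offset(model, is_convolutional: bool) -> None:
    # The sklearn API rejects non-convolutional models with offsets larger than 1.
    if not is_convolutional and len(model.get_offset()) > 1:
        raise ValueError("Use a convolutional model, or the pytorch API.")


def configure_for_all(dataset, model, multisession: bool) -> None:
    if multisession:
        for m in model:
            check_offset(m, is_convolutional=False)
    else:
        check_offset(model, is_convolutional=False)
    # Single entry point; a dataset collection dispatches per session internally.
    dataset.configure_for(model)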
cebra/solver/base.py

Lines changed: 52 additions & 38 deletions
@@ -37,6 +37,7 @@
 
 import literate_dataclasses as dataclasses
 import numpy as np
+import numpy.typing as npt
 import torch
 import torch.nn.functional as F
 import tqdm
@@ -89,32 +90,6 @@ def _check_indices(batch_start_idx: int, batch_end_idx: int,
     )
 
 
-def _get_batch(inputs: torch.Tensor, offset: cebra.data.Offset,
-               batch_start_idx: int, batch_end_idx: int) -> torch.Tensor:
-    """Get a batch of samples between the `batch_start_idx` and `batch_end_idx`.
-
-    Args:
-        inputs: Input data.
-        offset: Model offset.
-        batch_start_idx: Index of the first sample in the batch.
-        batch_end_idx: Index of the first sample in the batch.
-
-    Returns:
-        The batch.
-    """
-
-    if batch_start_idx == 0:  # First batch
-        indices = batch_start_idx, (batch_end_idx + offset.right - 1)
-    elif batch_end_idx == len(inputs):  # Last batch
-        indices = (batch_start_idx - offset.left), batch_end_idx
-    else:
-        indices = batch_start_idx - offset.left, batch_end_idx + offset.right - 1
-
-    _check_indices(indices[0], indices[1], offset, len(inputs))
-    batched_data = inputs[slice(*indices)]
-    return batched_data
-
-
 def _add_batched_zero_padding(batched_data: torch.Tensor,
                               offset: cebra.data.Offset, batch_start_idx: int,
                               batch_end_idx: int,
@@ -145,6 +120,45 @@ def _add_batched_zero_padding(batched_data: torch.Tensor,
     return batched_data
 
 
+def _get_batch(inputs: torch.Tensor, offset: Optional[cebra.data.Offset],
+               batch_start_idx: int, batch_end_idx: int,
+               pad_before_transform: bool) -> torch.Tensor:
+    """Get a batch of samples between the `batch_start_idx` and `batch_end_idx`.
+
+    Args:
+        inputs: Input data.
+        offset: Model offset.
+        batch_start_idx: Index of the first sample in the batch.
+        batch_end_idx: Index of the last sample in the batch.
+        pad_before_transform: If True, zero-pad the batched data.
+
+    Returns:
+        The batch.
+    """
+    if offset is None:
+        raise ValueError(f"offset cannot be null.")
+
+    if batch_start_idx == 0:  # First batch
+        indices = batch_start_idx, (batch_end_idx + offset.right - 1)
+    elif batch_end_idx == len(inputs):  # Last batch
+        indices = (batch_start_idx - offset.left), batch_end_idx
+    else:
+        indices = batch_start_idx - offset.left, batch_end_idx + offset.right - 1
+
+    _check_indices(indices[0], indices[1], offset, len(inputs))
+    batched_data = inputs[slice(*indices)]
+
+    if pad_before_transform:
+        batched_data = _add_batched_zero_padding(
+            batched_data=batched_data,
+            offset=offset,
+            batch_start_idx=batch_start_idx,
+            batch_end_idx=batch_end_idx,
+            num_samples=len(inputs))
+
+    return batched_data
+
+
 def _inference_transform(model: cebra.models.Model,
                          inputs: torch.Tensor) -> torch.Tensor:
     """Compute the embedding on the inputs using the model provided.
@@ -156,9 +170,7 @@ def _inference_transform(model: cebra.models.Model,
     Returns:
         The embedding.
     """
-    #TODO(rodrigo): I am not sure what is the best way with dealing with the types and
-    # device when using batched inference. This works for now.
-    inputs = inputs.type(torch.FloatTensor).to(next(model.parameters()).device)
+    inputs = inputs.float().to(next(model.parameters()).device)
 
     if isinstance(model, cebra.models.ConvolutionalModelMixin):
         # Fully convolutional evaluation, switch (T, C) -> (1, C, T)
@@ -228,15 +240,8 @@ def __getitem__(self, idx):
         batched_data = _get_batch(inputs=inputs,
                                   offset=offset,
                                   batch_start_idx=batch_start_idx,
-                                  batch_end_idx=batch_end_idx)
-
-        if pad_before_transform:
-            batched_data = _add_batched_zero_padding(
-                batched_data=batched_data,
-                offset=offset,
-                batch_start_idx=batch_start_idx,
-                batch_end_idx=batch_end_idx,
-                num_samples=len(inputs))
+                                  batch_end_idx=batch_end_idx,
+                                  pad_before_transform=pad_before_transform)
 
         output_batch = _inference_transform(model, batched_data)
         output.append(output_batch)
@@ -549,6 +554,15 @@ def transform(self,
         Returns:
             The output embedding.
         """
+        if isinstance(inputs, list):
+            raise NotImplementedError(
+                "Inputs to transform() should be the data for a single session."
+            )
+
+        elif not isinstance(inputs, torch.Tensor):
+            raise ValueError(
+                f"Inputs should be a torch.Tensor, not {type(inputs)}.")
+
         if not hasattr(self, "n_features"):
             raise ValueError(
                 f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with "

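The refactored `_get_batch` now owns both the index arithmetic and the optional zero padding, so `__getitem__` can pass `pad_before_transform` straight through. The index logic can be checked in isolation; below is a standalone sketch using a hypothetical `Offset` namedtuple in place of `cebra.data.Offset`:

# Standalone check of the batch-index arithmetic; Offset is a hypothetical stand-in.
from collections import namedtuple

import torch

Offset = namedtuple("Offset", ["left", "right"])

inputs = torch.randn(100, 3)          # (time, channels)
offset = Offset(left=2, right=3)


def batch_slice(start: int, end: int):
    if start == 0:                    # first batch: extend to the right only
        return start, end + offset.right - 1
    elif end == len(inputs):          # last batch: extend to the left only
        return start - offset.left, end
    return start - offset.left, end + offset.right - 1


print(batch_slice(0, 50))    # (0, 52)
print(batch_slice(25, 75))   # (23, 77)
print(batch_slice(50, 100))  # (48, 100)

When `pad_before_transform` is set, the same function applies `_add_batched_zero_padding` before returning. Separately, `transform()` now validates its input: a list raises `NotImplementedError` (it expects the data for a single session), and anything other than a `torch.Tensor` raises `ValueError`.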
cebra/solver/single_session.py

Lines changed: 4 additions & 1 deletion
@@ -227,7 +227,10 @@ def _select_model(
         self._check_is_session_id_valid(session_id=session_id)
 
         model = self.model.module
-        offset = model.get_offset()
+        if hasattr(model, 'get_offset'):
+            offset = model.get_offset()
+        else:
+            offset = None
         return model, offset
 
 
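A minimal illustration of the fallback introduced above, with a hypothetical model class that has no `get_offset` method:

# Hypothetical model without a get_offset method.
class PlainModule:
    pass


model = PlainModule()
offset = model.get_offset() if hasattr(model, "get_offset") else None
print(offset)  # -> None, instead of raising AttributeError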