Skip to content

Commit 4632c04

Browse files
committed
Add select_model to aux solvers
1 parent c9fa5c8 commit 4632c04

File tree

3 files changed

+180
-3
lines changed

3 files changed

+180
-3
lines changed

cebra/data/multi_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def configure_for(self, model: "cebra.models.Model"):
109109
"""Configure the dataset offset for the provided model.
110110
111111
Call this function before indexing the dataset. This sets the
112-
:py:attr:`cebra.data.Dataset.offset` attribute of the dataset.
112+
:py:attr:`cebra_data.Dataset.offset` attribute of the dataset.
113113
114114
Args:
115115
model: The model to configure the dataset for.

cebra/solver/multi_session.py

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
#
2222
"""Solver implementations for multi-session datasetes."""
2323

24-
from typing import List, Optional
24+
import copy
25+
from typing import List, Optional, Tuple, Union
2526

27+
import numpy.typing as npt
2628
import torch
2729

2830
import cebra
@@ -241,7 +243,16 @@ class MultiSessionAuxVariableSolver(MultiSessionSolver):
241243
"""Multi session training, contrasting neural data against behavior."""
242244

243245
_variant_name = "multi-session-aux"
244-
reference_model: torch.nn.Module
246+
reference_model: torch.nn.Module = None
247+
248+
def __post_init__(self):
249+
super().__post_init__()
250+
if self.reference_model is None:
251+
# NOTE(stes): This should work, according to this thread
252+
# https://discuss.pytorch.org/t/can-i-deepcopy-a-model/52192/19
253+
# and create a true copy of the model.
254+
self.reference_model = copy.deepcopy(self.model)
255+
self.reference_model.to(self.device)
245256

246257
def _inference(self, batches: List[cebra.data.Batch]) -> cebra.data.Batch:
247258
"""Given batches of input examples, computes the feature representations/embeddings.
@@ -276,3 +287,83 @@ def _inference(self, batches: List[cebra.data.Batch]) -> cebra.data.Batch:
276287
positive=pos.view(-1, num_features),
277288
negative=neg.view(-1, num_features),
278289
)
290+
291+
def _select_model(
292+
self,
293+
inputs: Union[torch.Tensor, List[torch.Tensor]],
294+
session_id: Optional[int] = None,
295+
use_reference_model: bool = False,
296+
) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module],
297+
cebra.data.datatypes.Offset]:
298+
""" Select the model based on the input dimension and session ID.
299+
300+
Args:
301+
inputs: Data to infer using the selected model.
302+
session_id: The session ID, an :py:class:`int` between 0 and
303+
the number of sessions -1 for multisession, and set to
304+
``None`` for single session.
305+
306+
Returns:
307+
The model (first returns) and the offset of the model (second returns).
308+
"""
309+
self._check_is_inputs_valid(inputs, session_id=session_id)
310+
self._check_is_session_id_valid(session_id=session_id)
311+
312+
if use_reference_model:
313+
model = self.reference_model[session_id]
314+
else:
315+
model = self.model[session_id]
316+
offset = model.get_offset()
317+
return model, offset
318+
319+
@torch.no_grad()
320+
def transform(self,
321+
inputs: Union[torch.Tensor, List[torch.Tensor], npt.NDArray],
322+
pad_before_transform: bool = True,
323+
session_id: Optional[int] = None,
324+
batch_size: Optional[int] = None,
325+
use_reference_model: bool = False) -> torch.Tensor:
326+
"""Compute the embedding.
327+
This function by default use ``model`` that was trained to encode the positive
328+
and negative samples. To use ``reference_model`` instead of ``model``
329+
``use_reference_model`` should be equal ``True``.
330+
Args:
331+
inputs: The input signal
332+
use_reference_model: Flag for using ``reference_model``
333+
Returns:
334+
The output embedding.
335+
"""
336+
if isinstance(inputs, list):
337+
raise NotImplementedError(
338+
"Inputs to transform() should be the data for a single session."
339+
)
340+
elif not isinstance(inputs, torch.Tensor):
341+
raise ValueError(
342+
f"Inputs should be a torch.Tensor, not {type(inputs)}.")
343+
344+
if not hasattr(self, "history") and len(self.history) > 0:
345+
raise ValueError(
346+
f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with "
347+
"appropriate arguments before using this estimator.")
348+
model, offset = self._select_model(
349+
inputs, session_id, use_reference_model=use_reference_model)
350+
351+
if len(offset) < 2 and pad_before_transform:
352+
pad_before_transform = False
353+
354+
model.eval()
355+
if batch_size is not None:
356+
output = abc_._batched_transform(
357+
model=model,
358+
inputs=inputs,
359+
offset=offset,
360+
batch_size=batch_size,
361+
pad_before_transform=pad_before_transform,
362+
)
363+
else:
364+
output = abc_._transform(model=model,
365+
inputs=inputs,
366+
offset=offset,
367+
pad_before_transform=pad_before_transform)
368+
369+
return output

cebra/solver/single_session.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from typing import List, Optional, Tuple, Union
2626

2727
import literate_dataclasses as dataclasses
28+
import numpy.typing as npt
2829
import torch
2930

3031
import cebra
@@ -206,6 +207,91 @@ def __post_init__(self):
206207
self.reference_model = copy.deepcopy(self.model)
207208
self.reference_model.to(self.model.device)
208209

210+
def _select_model(
211+
self,
212+
inputs: Union[torch.Tensor, List[torch.Tensor]],
213+
session_id: Optional[int] = None,
214+
use_reference_model: bool = False,
215+
) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module],
216+
cebra.data.datatypes.Offset]:
217+
""" Select the model based on the input dimension and session ID.
218+
219+
Args:
220+
inputs: Data to infer using the selected model.
221+
session_id: The session ID, an :py:class:`int` between 0 and
222+
the number of sessions -1 for multisession, and set to
223+
``None`` for single session.
224+
use_reference_model: Flag for using ``reference_model``.
225+
226+
Returns:
227+
The model (first returns) and the offset of the model (second returns).
228+
"""
229+
self._check_is_inputs_valid(inputs, session_id=session_id)
230+
self._check_is_session_id_valid(session_id=session_id)
231+
232+
if use_reference_model:
233+
model = self.reference_model
234+
else:
235+
model = self.model
236+
237+
if hasattr(model, 'get_offset'):
238+
offset = model.get_offset()
239+
else:
240+
offset = None
241+
return model, offset
242+
243+
@torch.no_grad()
244+
def transform(self,
245+
inputs: Union[torch.Tensor, List[torch.Tensor], npt.NDArray],
246+
pad_before_transform: bool = True,
247+
session_id: Optional[int] = None,
248+
batch_size: Optional[int] = None,
249+
use_reference_model: bool = False) -> torch.Tensor:
250+
"""Compute the embedding.
251+
This function by default use ``model`` that was trained to encode the positive
252+
and negative samples. To use ``reference_model`` instead of ``model``
253+
``use_reference_model`` should be equal ``True``.
254+
Args:
255+
inputs: The input signal
256+
use_reference_model: Flag for using ``reference_model``
257+
Returns:
258+
The output embedding.
259+
"""
260+
if isinstance(inputs, list):
261+
raise NotImplementedError(
262+
"Inputs to transform() should be the data for a single session."
263+
)
264+
elif not isinstance(inputs, torch.Tensor):
265+
raise ValueError(
266+
f"Inputs should be a torch.Tensor, not {type(inputs)}.")
267+
268+
if not hasattr(self, "history") and len(self.history) > 0:
269+
raise ValueError(
270+
f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with "
271+
"appropriate arguments before using this estimator.")
272+
model, offset = self._select_model(
273+
inputs, session_id, use_reference_model=use_reference_model)
274+
275+
if len(offset) < 2 and pad_before_transform:
276+
pad_before_transform = False
277+
278+
model.eval()
279+
if batch_size is not None:
280+
output = abc_._batched_transform(
281+
model=model,
282+
inputs=inputs,
283+
offset=offset,
284+
batch_size=batch_size,
285+
pad_before_transform=pad_before_transform,
286+
)
287+
else:
288+
output = abc_._transform(model=model,
289+
inputs=inputs,
290+
offset=offset,
291+
pad_before_transform=pad_before_transform)
292+
293+
return output
294+
209295
def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch:
210296
"""Given a batch of input examples, computes the feature representation/embedding.
211297

0 commit comments

Comments
 (0)