Commit 32fae46

Adapt unified code to get_model method
1 parent 619a662 · commit 32fae46

File tree

cebra/solver/single_session.py
cebra/solver/unified_session.py

2 files changed: +32 −24 lines

cebra/solver/single_session.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
 """Single session solvers embed a single pair of time series."""

 import copy
-from typing import Optional
+from typing import Optional, Tuple
 
 import literate_dataclasses as dataclasses
 import torch
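
Only the typing import changes in this file; the diff does not show where `Tuple` is used, so the snippet below is a purely hypothetical illustration (the function name and signature are assumptions, not part of this commit) of the kind of two-value return annotation that would need the new import.

# Hypothetical illustration only -- this function does not appear in the diff.
from typing import Optional, Tuple

import torch

import cebra.data.datatypes


def _select_model_sketch(
        session_id: Optional[int] = None,
) -> Tuple[torch.nn.Module, cebra.data.datatypes.Offset]:
    """Sketch of a signature returning both a model and its offset."""
    ...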

cebra/solver/unified_session.py

Lines changed: 31 additions & 23 deletions
@@ -21,7 +21,7 @@
 #
 """Solver implementations for unified-session datasets."""

-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
 
 import literate_dataclasses as dataclasses
 import numpy as np
@@ -93,9 +93,10 @@ def _check_is_inputs_valid(self, inputs: Union[torch.Tensor,
             f"(n_samples, {self.n_features}), got (n_samples, {inputs.shape[1]})."
         )

-    def _check_is_session_id_valid(self,
-                                   session_id: Optional[int] = None
-                                  ): # same as multi
+    def _check_is_session_id_valid(
+        self,
+        session_id: Optional[int] = None,
+    ): # same as multi
         """Check that the session ID provided is valid for the solver instance.

         The session ID must be non-null and between 0 and the number session in the dataset.
@@ -106,33 +107,27 @@ def _check_is_session_id_valid(self,

         if session_id is None:
             raise RuntimeError(
-                "No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape."
+                "No session_id provided: unified model requires a session_id as the target session to use to align the sessions."
             )
         if session_id >= self.num_sessions or session_id < 0:
             raise RuntimeError(
-                f"Invalid session_id {session_id}: session_id for the current multisession model must be between 0 and {self.num_sessions-1}."
+                f"Invalid session_id {session_id}: session_id for the current unified model must be between 0 and {self.num_sessions-1}."
             )

-    def _select_model(
-            self,
-            inputs: Union[torch.Tensor, List[torch.Tensor]],
-            session_id: Optional[int] = None
-    ) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module],
-               cebra.data.datatypes.Offset]:
-        """ Select the model based on the input dimension and session ID.
+    def _get_model(self, session_id: Optional[int] = None):
+        """Get the model for the given session ID.

         Args:
-            inputs: Data to infer using the selected model.
             session_id: The session ID, an :py:class:`int` between 0 and
                 the number of sessions -1 for multisession, and set to
                 ``None`` for single session.

         Returns:
-            The model (first returns) and the offset of the model (second returns).
+            The model for the given session ID.
         """
-        model = self.model
-        offset = model.get_offset()
-        return model, offset
+        self._check_is_session_id_valid(session_id=session_id)
+        self._check_is_fitted()
+        return self.model

     def _single_model_inference(self, batch: cebra.data.Batch,
                                 model: torch.nn.Module) -> cebra.data.Batch:
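
This hunk carries the core change of the commit: `_select_model`, which returned both the model and its offset, is replaced by `_get_model`, which validates the session ID and the fitted state and returns only the single shared model; callers now read the offset from the model via `get_offset()`. Below is a self-contained toy sketch of that validate-then-return contract (the class and its attributes are stand-ins for illustration, not CEBRA's actual solver).

# Toy stand-in (not CEBRA code) mirroring the _get_model contract from the diff.
from typing import Optional

import torch


class UnifiedSolverSketch:

    def __init__(self, num_sessions: int, model: torch.nn.Module):
        self.num_sessions = num_sessions
        self.model = model          # one shared model for all sessions
        self._fitted = True         # pretend fit() has already run

    def _check_is_fitted(self):
        if not self._fitted:
            raise ValueError("The solver must be fitted before inference.")

    def _check_is_session_id_valid(self, session_id: Optional[int] = None):
        # Same checks as the diff: a session_id is required and must be in range.
        if session_id is None:
            raise RuntimeError("No session_id provided: unified model requires a "
                               "session_id as the target session to align the sessions.")
        if session_id >= self.num_sessions or session_id < 0:
            raise RuntimeError(f"Invalid session_id {session_id}: must be between "
                               f"0 and {self.num_sessions - 1}.")

    def _get_model(self, session_id: Optional[int] = None) -> torch.nn.Module:
        # New behaviour: validate first, then return the shared model.
        self._check_is_session_id_valid(session_id=session_id)
        self._check_is_fitted()
        return self.model


solver = UnifiedSolverSketch(num_sessions=3, model=torch.nn.Linear(10, 2))
model = solver._get_model(session_id=0)   # ok: valid session on a fitted solver
# solver._get_model()                     # would raise RuntimeError (no session_id)

Compared with the removed `_select_model`, the offset is no longer part of the return value; the `transform` path in the next hunk queries `self.model.get_offset()` directly.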
@@ -249,13 +244,26 @@ def transform(self,
                 ref_idx=torch.arange(batch_start, batch_end),
                 session_id=session_id).to(self.device)

-            refs_data_batch = [
+            refs_data_batch = torch.cat([
                 session[refs_idx_batch[session_id]]
                 for session_id, session in enumerate(dataset.iter_sessions())
-            ]
-            refs_data_batch_embeddings.append(super().transform(
-                torch.cat(refs_data_batch, dim=1).squeeze(),
-                pad_before_transform=pad_before_transform))
+            ],
+                                        dim=1).squeeze()
+            # refs_data_batch_embeddings.append(super().transform(
+            #     torch.cat(refs_data_batch, dim=1).squeeze(),
+            #     pad_before_transform=pad_before_transform))
+
+            if len(self.model.get_offset()) < 2 and pad_before_transform:
+                pad_before_transform = False
+
+            self.model.eval()
+            refs_data_batch_embeddings.append(
+                self._transform(model=self.model,
+                                inputs=refs_data_batch,
+                                pad_before_transform=pad_before_transform,
+                                offset=self.model.get_offset(),
+                                batch_size=batch_size))
+
         return torch.cat(refs_data_batch_embeddings, dim=0)

     @torch.no_grad()
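
In `transform`, the previous `super().transform(...)` call (left commented out in the commit) is replaced by three steps: concatenate the per-session reference slices along the feature axis, disable padding when the model's receptive field (the offset length) is shorter than 2 samples, and embed the batch through `self._transform` with an explicit model, offset and batch size. Below is a self-contained sketch of the tensor plumbing only, with made-up session shapes and `receptive_field` standing in for `len(self.model.get_offset())`; `_transform` itself is not reproduced.

import torch

# Made-up example: 3 sessions with 100 reference samples each and different feature counts.
sessions = [torch.randn(100, d) for d in (12, 7, 30)]
refs_idx_batch = [torch.arange(0, 20) for _ in sessions]  # indices selected per session

# As in the diff: index each session, then concatenate along the feature axis (dim=1).
refs_data_batch = torch.cat(
    [session[refs_idx_batch[session_id]]
     for session_id, session in enumerate(sessions)],
    dim=1).squeeze()
print(refs_data_batch.shape)  # torch.Size([20, 49]) -> 12 + 7 + 30 features

# As in the diff: skip padding for models whose receptive field is a single sample.
pad_before_transform = True
receptive_field = 1  # stands in for len(self.model.get_offset())
if receptive_field < 2 and pad_before_transform:
    pad_before_transform = False

The trailing `.squeeze()` mirrors the diff and only has an effect when one of the resulting dimensions is 1, e.g. a single-sample batch.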
