Skip to content

Commit d91949f

Browse files
committed
Implement extra review comments
1 parent 63d5a7c commit d91949f

File tree

1 file changed

+22
-88
lines changed

1 file changed

+22
-88
lines changed

cebra/solver/unified_session.py

Lines changed: 22 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,28 @@
1919
# See the License for the specific language governing permissions and
2020
# limitations under the License.
2121
#
22-
"""Solver implementations for unified-session datasets."""
22+
"""Unified session solver for multi-session contrastive learning.
23+
24+
We added support for training contrastive models on unified-session datasets.
25+
This allows users to align and embed multiple sessions into a common latent
26+
space using a single shared model.
27+
28+
This module implements the :py:class:`~cebra.solver.unified_session.UnifiedSolver`, which
29+
is designed for training a single embedding model across multiple recording sessions.
30+
Unlike the standard multi-session solvers, the unified session approach uses
31+
a global model that requires session-specific information for sampling but maintains
32+
a shared representation across all data.
33+
34+
Features:
35+
- Single model inference across all sessions.
36+
- Batched transform.
37+
- Compatibility with :py:class:`~cebra.data.UnifiedDataset` and :py:class:`~cebra.data.UnifiedLoader`.
38+
39+
See Also:
40+
:py:class:`~cebra.solver.base.Solver`
41+
:py:class:`~cebra.data.UnifiedDataset`
42+
:py:class:`~cebra.data.UnifiedLoader`
43+
"""
2344

2445
from typing import List, Optional, Union
2546

@@ -259,93 +280,6 @@ def transform(self,
259280

260281
return torch.cat(refs_data_batch_embeddings, dim=0)
261282

262-
@torch.no_grad()
263-
def single_session_transform(
264-
self,
265-
inputs: Union[torch.Tensor, List[torch.Tensor]],
266-
session_id: Optional[int] = None,
267-
pad_before_transform: bool = True,
268-
padding_mode: str = "zero",
269-
batch_size: Optional[int] = 100) -> torch.Tensor:
270-
"""Compute the embedding for the `session_id`th session of the dataset without label alignment.
271-
272-
By padding the channels that don't correspond to the {session_id}th session, we can
273-
use a single session solver without behavioral alignment.
274-
275-
Note: The embedding will not benefit from the behavioral alignment, and consequently
276-
from the information contained in the other sessions. We expect behavioral
277-
decoding performance comparable to that of a single-session encoder.
278-
279-
Args:
280-
inputs: The input signal for all sessions.
281-
session_id: The session ID, an :py:class:`int` between 0 and
282-
the number of sessions.
283-
pad_before_transform: If True, pads the input before applying the transform.
284-
padding_mode: The mode to use for padding. Padding is done in the following
285-
ways, either by padding all the other sessions to the length of the
286-
{session_id}th session, or by resampling all sessions in a random way:
287-
- `time`: pads the inputs that are not inferred to the maximum length of
288-
the session and then zeros so that the length is the same as the
289-
{session_id}th session length.
290-
- `zero`: pads the inputs that are not inferred with zeros so that the
291-
length is the same as the {session_id}th session length.
292-
- `poisson`: pads the inputs that are not inferred with a poisson distribution
293-
so that the length is the same as the {session_id}th session length.
294-
- `random`: pads all sessions with random values sampled from a normal
295-
distribution.
296-
- `random_poisson`: pads all sessions with random values sampled from a
297-
poisson distribution.
298-
299-
batch_size: If not None, batched inference will be applied.
300-
301-
Returns:
302-
The output embedding for the session corresponding to the provided ID `session_id`. The shape
303-
is ``(num_samples(session_id), output_dimension)``.
304-
"""
305-
inputs = [session.to(self.device) for session in inputs]
306-
307-
zero_shape = inputs[session_id].shape[0]
308-
309-
if padding_mode == "time" or padding_mode == "zero" or padding_mode == "poisson":
310-
for i in range(len(inputs)):
311-
if i != session_id:
312-
if padding_mode == "time":
313-
if inputs[i].shape[0] >= zero_shape:
314-
inputs[i] = inputs[i][:zero_shape]
315-
else:
316-
inputs[i] = torch.cat(
317-
(inputs[i],
318-
torch.zeros(
319-
(zero_shape - inputs[i].shape[0],
320-
inputs[i].shape[1])).to(self.device)))
321-
if padding_mode == "poisson":
322-
inputs[i] = torch.poisson(
323-
torch.ones((zero_shape, inputs[i].shape[1])))
324-
if padding_mode == "zero":
325-
inputs[i] = torch.zeros(
326-
(zero_shape, inputs[i].shape[1]))
327-
padded_inputs = torch.cat(
328-
[session.to(self.device) for session in inputs], dim=1)
329-
330-
elif padding_mode == "random_poisson":
331-
padded_inputs = torch.poisson(
332-
torch.ones((zero_shape, self.n_features)))
333-
elif padding_mode == "random":
334-
padded_inputs = torch.normal(
335-
torch.zeros((zero_shape, self.n_features)),
336-
torch.ones((zero_shape, self.n_features)))
337-
338-
else:
339-
raise ValueError(
340-
f"Invalid padding mode: {padding_mode}. "
341-
"Choose from 'time', 'zero', 'poisson', 'random', or 'random_poisson'."
342-
)
343-
344-
# Single session solver transform call
345-
return super().transform(inputs=padded_inputs,
346-
pad_before_transform=pad_before_transform,
347-
batch_size=batch_size)
348-
349283
@torch.no_grad()
350284
def decoding(self,
351285
train_loader: cebra.data.Loader,

0 commit comments

Comments
 (0)