2121#
2222"""Solver implementations for unified-session datasets."""
2323
24- from typing import List , Optional , Tuple , Union
24+ from typing import List , Optional , Union
2525
2626import literate_dataclasses as dataclasses
2727import numpy as np
@@ -93,9 +93,10 @@ def _check_is_inputs_valid(self, inputs: Union[torch.Tensor,
9393 f"(n_samples, { self .n_features } ), got (n_samples, { inputs .shape [1 ]} )."
9494 )
9595
96- def _check_is_session_id_valid (self ,
97- session_id : Optional [int ] = None
98- ): # same as multi
96+ def _check_is_session_id_valid (
97+ self ,
98+ session_id : Optional [int ] = None ,
99+ ): # same as multi
99100 """Check that the session ID provided is valid for the solver instance.
100101
101102 The session ID must be non-null and between 0 and the number session in the dataset.
@@ -106,33 +107,27 @@ def _check_is_session_id_valid(self,
106107
107108 if session_id is None :
108109 raise RuntimeError (
109- "No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape ."
110+ "No session_id provided: unified model requires a session_id as the target session to use to align the sessions ."
110111 )
111112 if session_id >= self .num_sessions or session_id < 0 :
112113 raise RuntimeError (
113- f"Invalid session_id { session_id } : session_id for the current multisession model must be between 0 and { self .num_sessions - 1 } ."
114+ f"Invalid session_id { session_id } : session_id for the current unified model must be between 0 and { self .num_sessions - 1 } ."
114115 )
115116
116- def _select_model (
117- self ,
118- inputs : Union [torch .Tensor , List [torch .Tensor ]],
119- session_id : Optional [int ] = None
120- ) -> Tuple [Union [List [torch .nn .Module ], torch .nn .Module ],
121- cebra .data .datatypes .Offset ]:
122- """ Select the model based on the input dimension and session ID.
117+ def _get_model (self , session_id : Optional [int ] = None ):
118+ """Get the model for the given session ID.
123119
124120 Args:
125- inputs: Data to infer using the selected model.
126121 session_id: The session ID, an :py:class:`int` between 0 and
127122 the number of sessions -1 for multisession, and set to
128123 ``None`` for single session.
129124
130125 Returns:
131- The model (first returns) and the offset of the model (second returns) .
126+ The model for the given session ID .
132127 """
133- model = self .model
134- offset = model . get_offset ()
135- return model , offset
128+ self ._check_is_session_id_valid ( session_id = session_id )
129+ self . _check_is_fitted ()
130+ return self . model
136131
137132 def _single_model_inference (self , batch : cebra .data .Batch ,
138133 model : torch .nn .Module ) -> cebra .data .Batch :
@@ -249,13 +244,26 @@ def transform(self,
249244 ref_idx = torch .arange (batch_start , batch_end ),
250245 session_id = session_id ).to (self .device )
251246
252- refs_data_batch = [
247+ refs_data_batch = torch . cat ( [
253248 session [refs_idx_batch [session_id ]]
254249 for session_id , session in enumerate (dataset .iter_sessions ())
255- ]
256- refs_data_batch_embeddings .append (super ().transform (
257- torch .cat (refs_data_batch , dim = 1 ).squeeze (),
258- pad_before_transform = pad_before_transform ))
250+ ],
251+ dim = 1 ).squeeze ()
252+ # refs_data_batch_embeddings.append(super().transform(
253+ # torch.cat(refs_data_batch, dim=1).squeeze(),
254+ # pad_before_transform=pad_before_transform))
255+
256+ if len (self .model .get_offset ()) < 2 and pad_before_transform :
257+ pad_before_transform = False
258+
259+ self .model .eval ()
260+ refs_data_batch_embeddings .append (
261+ self ._transform (model = self .model ,
262+ inputs = refs_data_batch ,
263+ pad_before_transform = pad_before_transform ,
264+ offset = self .model .get_offset (),
265+ batch_size = batch_size ))
266+
259267 return torch .cat (refs_data_batch_embeddings , dim = 0 )
260268
261269 @torch .no_grad ()
0 commit comments