19 | 19 | # See the License for the specific language governing permissions and |
20 | 20 | # limitations under the License. |
21 | 21 | # |
22 | | -"""Solver implementations for unified-session datasets.""" |
| 22 | +"""Unified session solver for multi-session contrastive learning. |
| 23 | + |
| 24 | +This module supports training contrastive models on unified-session datasets, |
| 25 | +allowing users to align and embed multiple sessions into a common latent |
| 26 | +space using a single shared model. |
| 27 | + |
| 28 | +It implements the :py:class:`~cebra.solver.unified.UnifiedSolver`, which is |
| 29 | +designed to train a single embedding model across multiple recording sessions. |
| 30 | +Unlike the standard multi-session solvers, the unified-session approach uses a |
| 31 | +global model that requires session-specific information for sampling but maintains |
| 32 | +a shared representation across all data. |
| 33 | + |
| 34 | +Features: |
| 35 | +- Single-model inference across all sessions. |
| 36 | +- Batched transform for memory-efficient inference. |
| 37 | +- Compatibility with :py:class:`~cebra.data.UnifiedDataset` and :py:class:`~cebra.data.UnifiedLoader`. |
| 38 | + |
| 39 | +See Also: |
| 40 | + :py:class:`~cebra.solver.base.Solver` |
| 41 | + :py:class:`~cebra.data.UnifiedDataset` |
| 42 | + :py:class:`~cebra.data.UnifiedLoader` |
| 43 | +""" |
23 | 44 | |
24 | 45 | from typing import List, Optional, Union |
25 | 46 | |
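As a rough illustration of the workflow the docstring above describes, a minimal usage sketch follows. It is a sketch only: the variable names (`session_datasets`, `model`, `criterion`, `optimizer`, `inputs`) and the exact constructor arguments are assumptions for illustration, not the verified API.

    # Hypothetical sketch: train one shared model across several sessions,
    # then embed one session with the batched transform.
    import cebra.data
    import cebra.solver

    # One dataset per recording session, combined into a unified dataset.
    dataset = cebra.data.UnifiedDataset(*session_datasets)
    loader = cebra.data.UnifiedLoader(dataset, num_steps=5000)

    # A single shared model embeds all sessions into one latent space.
    solver = cebra.solver.UnifiedSolver(model=model,
                                        criterion=criterion,
                                        optimizer=optimizer)
    solver.fit(loader)

    # Batched inference for the first session.
    embedding = solver.transform(inputs, session_id=0, batch_size=100)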
@@ -259,93 +280,6 @@ def transform(self, |
259 | 280 | |
260 | 281 |         return torch.cat(refs_data_batch_embeddings, dim=0) |
261 | 282 | |
262 | | -    @torch.no_grad() |
263 | | -    def single_session_transform( |
264 | | -            self, |
265 | | -            inputs: Union[torch.Tensor, List[torch.Tensor]], |
266 | | -            session_id: Optional[int] = None, |
267 | | -            pad_before_transform: bool = True, |
268 | | -            padding_mode: str = "zero", |
269 | | -            batch_size: Optional[int] = 100) -> torch.Tensor: |
270 | | - """Compute the embedding for the `session_id`th session of the dataset without labels alignment. |
271 | | -
|
272 | | - By padding the channels that don't correspond to the {session_id}th session, we can |
273 | | - use a single session solver without behavioral alignment. |
274 | | -
|
275 | | - Note: The embedding will not benefit from the behavioral alignment, and consequently |
276 | | - from the information contained in the other sessions. We expect single session encoder |
277 | | - behavioral decoding performances. |
278 | | -
|
279 | | - Args: |
280 | | - inputs: The input signal for all sessions. |
281 | | - session_id: The session ID, an :py:class:`int` between 0 and |
282 | | - the number of sessions. |
283 | | - pad_before_transform: If True, pads the input before applying the transform. |
284 | | - padding_mode: The mode to use for padding. Padding is done in the following |
285 | | - ways, either by padding all the other sessions to the length of the |
286 | | - {session_id}th session, or by resampling all sessions in a random way: |
287 | | - - `time`: pads the inputs that are not inferred to the maximum length of |
288 | | - the session and then zeros so that the length is the same as the |
289 | | - {session_id}th session length. |
290 | | - - `zero`: pads the inputs that are not inferred with zeros so that the |
291 | | - length is the same as the {session_id}th session length. |
292 | | - - `poisson`: pads the inputs that are not inferred with a poisson distribution |
293 | | - so that the length is the same as the {session_id}th session length. |
294 | | - - `random`: pads all sessions with random values sampled from a normal |
295 | | - distribution. |
296 | | - - `random_poisson`: pads all sessions with random values sampled from a |
297 | | - poisson distribution. |
298 | | -
|
299 | | - batch_size: If not None, batched inference will be applied. |
300 | | -
|
301 | | - Returns: |
302 | | - The output embedding for the session corresponding to the provided ID `session_id`. The shape |
303 | | - is (num_samples(session_id), output_dimension)``. |
304 | | - """ |
305 | | -        inputs = [session.to(self.device) for session in inputs] |
306 | | - |
307 | | -        zero_shape = inputs[session_id].shape[0] |
308 | | - |
309 | | -        if padding_mode in ("time", "zero", "poisson"): |
310 | | -            for i in range(len(inputs)): |
311 | | -                if i != session_id: |
312 | | -                    if padding_mode == "time": |
313 | | -                        if inputs[i].shape[0] >= zero_shape: |
314 | | -                            inputs[i] = inputs[i][:zero_shape] |
315 | | -                        else: |
316 | | -                            inputs[i] = torch.cat( |
317 | | -                                (inputs[i], |
318 | | -                                 torch.zeros( |
319 | | -                                     (zero_shape - inputs[i].shape[0], |
320 | | -                                      inputs[i].shape[1])).to(self.device))) |
321 | | -                    if padding_mode == "poisson": |
322 | | -                        inputs[i] = torch.poisson( |
323 | | -                            torch.ones((zero_shape, inputs[i].shape[1]))) |
324 | | -                    if padding_mode == "zero": |
325 | | -                        inputs[i] = torch.zeros( |
326 | | -                            (zero_shape, inputs[i].shape[1])) |
327 | | -            padded_inputs = torch.cat( |
328 | | -                [session.to(self.device) for session in inputs], dim=1) |
329 | | - |
330 | | -        elif padding_mode == "random_poisson": |
331 | | -            padded_inputs = torch.poisson( |
332 | | -                torch.ones((zero_shape, self.n_features))).to(self.device) |
333 | | -        elif padding_mode == "random": |
334 | | -            padded_inputs = torch.normal( |
335 | | -                torch.zeros((zero_shape, self.n_features)), |
336 | | -                torch.ones((zero_shape, self.n_features))).to(self.device) |
337 | | - |
338 | | -        else: |
339 | | -            raise ValueError( |
340 | | -                f"Invalid padding mode: {padding_mode}. " |
341 | | -                "Choose from 'time', 'zero', 'poisson', 'random', or 'random_poisson'." |
342 | | -            ) |
343 | | - |
344 | | -        # Single-session solver transform on the padded, concatenated input. |
345 | | -        return super().transform(inputs=padded_inputs, |
346 | | -                                 pad_before_transform=pad_before_transform, |
347 | | -                                 batch_size=batch_size) |
348 | | - |
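The removed method above relies on a channel-padding trick: every session except the target one is replaced by zeros (or noise) of matching length, so the concatenated input keeps the full channel width the shared model expects while only the target session carries signal. Below is a self-contained sketch of that idea; the helper name `pad_sessions` and the toy shapes are ours, not from the codebase.

    import torch

    def pad_sessions(inputs, session_id, padding_mode="zero"):
        # Length of the target session; all other sessions are padded to it.
        length = inputs[session_id].shape[0]
        padded = []
        for i, session in enumerate(inputs):
            if i == session_id:
                padded.append(session[:length])
            elif padding_mode == "zero":
                # Non-target channels carry no signal.
                padded.append(torch.zeros(length, session.shape[1]))
            elif padding_mode == "poisson":
                # Rate-1 Poisson noise instead of zeros.
                padded.append(torch.poisson(torch.ones(length, session.shape[1])))
            else:
                raise ValueError(f"Unsupported mode: {padding_mode}")
        # Concatenate along channels: full feature width, one session's signal.
        return torch.cat(padded, dim=1)

    sessions = [torch.randn(500, 30), torch.randn(500, 45)]
    padded = pad_sessions(sessions, session_id=0)
    print(padded.shape)  # torch.Size([500, 75])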
349 | 283 |     @torch.no_grad() |
350 | 284 |     def decoding(self, |
351 | 285 |                  train_loader: cebra.data.Loader, |