|
21 | 21 | # |
22 | 22 | """Solver implementations for multi-session datasetes.""" |
23 | 23 |
|
24 | | -from typing import List, Optional |
| 24 | +import copy |
| 25 | +from typing import List, Optional, Tuple, Union |
25 | 26 |
|
| 27 | +import numpy.typing as npt |
26 | 28 | import torch |
27 | 29 |
|
28 | 30 | import cebra |
@@ -241,7 +243,16 @@ class MultiSessionAuxVariableSolver(MultiSessionSolver): |
241 | 243 | """Multi session training, contrasting neural data against behavior.""" |
242 | 244 |
|
243 | 245 |     _variant_name = "multi-session-aux"
244 | | -    reference_model: torch.nn.Module
| 246 | +    reference_model: Optional[torch.nn.Module] = None
| 247 | + |
| 248 | +    def __post_init__(self):
| 249 | +        super().__post_init__()
| 250 | +        if self.reference_model is None:
| 251 | +            # NOTE(stes): This should work, according to this thread
| 252 | +            # https://discuss.pytorch.org/t/can-i-deepcopy-a-model/52192/19
| 253 | +            # and create a true copy of the model.
| 254 | +            self.reference_model = copy.deepcopy(self.model)
| 255 | +            self.reference_model.to(self.device)
245 | 256 |
|
246 | 257 |     def _inference(self, batches: List[cebra.data.Batch]) -> cebra.data.Batch:
247 | 258 |         """Given batches of input examples, computes the feature representations/embeddings.
@@ -276,3 +287,83 @@ def _inference(self, batches: List[cebra.data.Batch]) -> cebra.data.Batch: |
276 | 287 |             positive=pos.view(-1, num_features),
277 | 288 |             negative=neg.view(-1, num_features),
278 | 289 |         )
| 290 | + |
| 291 | +    def _select_model(
| 292 | +        self,
| 293 | +        inputs: Union[torch.Tensor, List[torch.Tensor]],
| 294 | +        session_id: Optional[int] = None,
| 295 | +        use_reference_model: bool = False,
| 296 | +    ) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module],
| 297 | +               cebra.data.datatypes.Offset]:
| 298 | +        """Select the model based on the input dimension and session ID.
| 299 | +
|
| 300 | +        Args:
| 301 | +            inputs: Data to infer using the selected model.
| 302 | +            session_id: The session ID, an :py:class:`int` between 0 and
| 303 | +                the number of sessions -1 for multisession, and set to
| 304 | +                ``None`` for single session.
| 305 | +
|
| 306 | +        Returns:
| 307 | +            The model (first return value) and the offset of the model (second return value).
| 308 | +        """
| 309 | +        self._check_is_inputs_valid(inputs, session_id=session_id)
| 310 | +        self._check_is_session_id_valid(session_id=session_id)
| 311 | +
| 312 | +        if use_reference_model:
| 313 | +            model = self.reference_model[session_id]
| 314 | +        else:
| 315 | +            model = self.model[session_id]
| 316 | +        offset = model.get_offset()
| 317 | +        return model, offset
| 318 | + |
| 319 | +    @torch.no_grad()
| 320 | +    def transform(self,
| 321 | +                  inputs: Union[torch.Tensor, List[torch.Tensor], npt.NDArray],
| 322 | +                  pad_before_transform: bool = True,
| 323 | +                  session_id: Optional[int] = None,
| 324 | +                  batch_size: Optional[int] = None,
| 325 | +                  use_reference_model: bool = False) -> torch.Tensor:
| 326 | + """Compute the embedding. |
| 327 | + This function by default use ``model`` that was trained to encode the positive |
| 328 | + and negative samples. To use ``reference_model`` instead of ``model`` |
| 329 | + ``use_reference_model`` should be equal ``True``. |
| 330 | + Args: |
| 331 | + inputs: The input signal |
| 332 | + use_reference_model: Flag for using ``reference_model`` |
| 333 | + Returns: |
| 334 | + The output embedding. |
| 335 | + """ |
| 336 | +        if isinstance(inputs, list):
| 337 | +            raise NotImplementedError(
| 338 | +                "Inputs to transform() should be the data for a single session."
| 339 | +            )
| 340 | +        elif not isinstance(inputs, torch.Tensor):
| 341 | +            raise ValueError(
| 342 | +                f"Inputs should be a torch.Tensor, not {type(inputs)}.")
| 343 | + |
| 344 | +        if not hasattr(self, "history") or len(self.history) == 0:
| 345 | +            raise ValueError(
| 346 | +                f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with "
| 347 | +                "appropriate arguments before using this estimator.")
| 348 | +        model, offset = self._select_model(
| 349 | +            inputs, session_id, use_reference_model=use_reference_model)
| 350 | + |
| 351 | +        if len(offset) < 2 and pad_before_transform:
| 352 | +            pad_before_transform = False
| 353 | +
| 354 | +        model.eval()
| 355 | +        if batch_size is not None:
| 356 | +            output = abc_._batched_transform(
| 357 | +                model=model,
| 358 | +                inputs=inputs,
| 359 | +                offset=offset,
| 360 | +                batch_size=batch_size,
| 361 | +                pad_before_transform=pad_before_transform,
| 362 | +            )
| 363 | +        else:
| 364 | +            output = abc_._transform(model=model,
| 365 | +                                     inputs=inputs,
| 366 | +                                     offset=offset,
| 367 | +                                     pad_before_transform=pad_before_transform)
| 368 | + |
| 369 | +        return output
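
As a usage note on the new ``use_reference_model`` flag: the sketch below shows how the session-indexed model selection and the reference-model path could be exercised after fitting. It is a minimal, hypothetical example; ``solver`` and ``session_data`` are stand-in names for an already fitted ``MultiSessionAuxVariableSolver`` and its per-session input tensors, and are not part of this diff.

import torch

# Hypothetical stand-ins: `solver` is a fitted MultiSessionAuxVariableSolver,
# `session_data` holds one (time, neurons) tensor per session.
session_data = [torch.randn(1000, 30), torch.randn(1000, 42)]

for session_id, data in enumerate(session_data):
    # Default path: embed with the trained `model` of this session.
    embedding = solver.transform(data, session_id=session_id, batch_size=512)
    # New path added here: embed with the `reference_model` copy instead.
    reference_embedding = solver.transform(data,
                                           session_id=session_id,
                                           batch_size=512,
                                           use_reference_model=True)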
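On the NOTE(stes) comment in ``__post_init__``: the change relies on ``copy.deepcopy`` producing a fully independent copy of the model. A small sanity check of that assumption, using a plain ``torch.nn.Linear`` as a stand-in for the CEBRA model (not part of this diff), could look like this:

import copy

import torch

# Stand-in module; the solver would deepcopy its CEBRA model instead.
model = torch.nn.Linear(10, 3)
reference_model = copy.deepcopy(model)

# The copy starts out numerically identical to the original ...
assert all(
    torch.equal(p, q)
    for p, q in zip(model.parameters(), reference_model.parameters()))

# ... but shares no parameter storage, so later optimizer steps on `model`
# leave `reference_model` untouched.
assert all(
    p.data_ptr() != q.data_ptr()
    for p, q in zip(model.parameters(), reference_model.parameters()))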