Skip to content

Commit 9db3e37

Browse files
CeliaBenquetstes
authored andcommitted
Add some coverage
1 parent c845ec3 commit 9db3e37

File tree

3 files changed

+47
-1
lines changed

3 files changed

+47
-1
lines changed

cebra/solver/base.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,12 @@ def _get_loader(self, loader):
360360

361361
@abc.abstractmethod
362362
def _set_fitted_params(self, loader: cebra.data.Loader):
363+
"""Set parameters once the solver is fitted.
364+
365+
Args:
366+
loader: Loader used to fit the solver.
367+
"""
368+
363369
raise NotImplementedError
364370

365371
def fit(
@@ -507,6 +513,11 @@ def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int):
507513

508514
@abc.abstractmethod
509515
def _check_is_session_id_valid(self, session_id: Optional[int] = None):
516+
"""Check that the session ID provided is valid for the solver instance.
517+
518+
Args:
519+
session_id: The session ID to check.
520+
"""
510521
raise NotImplementedError
511522

512523
@abc.abstractmethod
@@ -530,7 +541,7 @@ def _select_model(
530541

531542
@torch.no_grad()
532543
def transform(self,
533-
inputs: torch.Tensor,
544+
inputs: Union[torch.Tensor, List[torch.Tensor], npt.NDArray],
534545
pad_before_transform: bool = True,
535546
session_id: Optional[int] = None,
536547
batch_size: Optional[int] = None) -> torch.Tensor:

cebra/solver/multi_session.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,17 @@ def _inference(self, batches: List[cebra.data.Batch]) -> cebra.data.Batch:
126126
)
127127

128128
def _set_fitted_params(self, loader: cebra.data.Loader):
129+
"""Set parameters once the solver is fitted.
130+
131+
In multi session solver, the number of session is set to the number of
132+
sessions in the dataset of the loader and the number of
133+
features is set as a list corresponding to the number of neurons in
134+
each dataset.
135+
136+
Args:
137+
loader: Loader used to fit the solver.
138+
"""
139+
129140
self.num_sessions = loader.dataset.num_sessions
130141
self.n_features = [
131142
loader.dataset.get_input_dimension(session_id)
@@ -152,6 +163,14 @@ def _check_is_inputs_valid(self, inputs: torch.Tensor,
152163
)
153164

154165
def _check_is_session_id_valid(self, session_id: Optional[int]):
166+
"""Check that the session ID provided is valid for the solver instance.
167+
168+
The session ID must be non-null and between 0 and the number session in the dataset.
169+
170+
Args:
171+
session_id: The session ID to check.
172+
"""
173+
155174
if session_id is None:
156175
raise RuntimeError(
157176
"No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape."

cebra/solver/single_session.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ def parameters(self, session_id: Optional[int] = None):
5555
yield parameter
5656

5757
def _set_fitted_params(self, loader: cebra.data.Loader):
58+
"""Set parameters once the solver is fitted.
59+
60+
In single session solver, the number of session is set to None and the number of
61+
features is set to the number of neurons in the dataset.
62+
63+
Args:
64+
loader: Loader used to fit the solver.
65+
"""
5866
self.num_sessions = None
5967
self.n_features = loader.dataset.input_dimension
6068

@@ -77,6 +85,14 @@ def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int):
7785
)
7886

7987
def _check_is_session_id_valid(self, session_id: Optional[int] = None):
88+
"""Check that the session ID provided is valid for the solver instance.
89+
90+
The session ID must be null or equal to 0.
91+
92+
Args:
93+
session_id: The session ID to check.
94+
"""
95+
8096
if session_id is not None and session_id > 0:
8197
raise RuntimeError(
8298
f"Invalid session_id {session_id}: single session models only takes an optional null session_id."

0 commit comments

Comments
 (0)