Skip to content

Commit 217a8a7

Browse files
committed
Fix all tests but xcebra tests
1 parent acd2111 commit 217a8a7

File tree

6 files changed

+431
-410
lines changed

6 files changed

+431
-410
lines changed

cebra/integrations/sklearn/cebra.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,14 +1053,12 @@ def _partial_fit(
10531053

10541054
# Save variables of interest as semi-private attributes
10551055
self.model_ = model
1056-
self.n_features_ = ([
1057-
loader.dataset.get_input_dimension(session_id)
1058-
for session_id in range(loader.dataset.num_sessions)
1059-
] if is_multisession else loader.dataset.input_dimension)
1056+
1057+
self.n_features_ = solver.n_features
1058+
self.num_sessions_ = solver.num_sessions
10601059
self.solver_ = solver
10611060
self.n_features_in_ = ([model[n].num_input for n in range(len(model))]
10621061
if is_multisession else model.num_input)
1063-
self.num_sessions_ = loader.dataset.num_sessions if is_multisession else None
10641062

10651063
return self
10661064

@@ -1256,7 +1254,7 @@ def transform(self,
12561254

12571255
return output.detach().cpu().numpy()
12581256

1259-
# Deprecated, kept for testing.
1257+
#NOTE: Deprecated, as transform is now handled in the solver but kept for testing.
12601258
def transform_deprecated(self,
12611259
X: Union[npt.NDArray, torch.Tensor],
12621260
session_id: Optional[int] = None) -> npt.NDArray:

cebra/solver/base.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,11 @@ def __getitem__(self, idx):
234234
index_dataset = IndexDataset(inputs)
235235
index_dataloader = DataLoader(index_dataset, batch_size=batch_size)
236236

237+
if len(index_dataloader) < 2:
238+
raise ValueError(
239+
f"Number of batches must be greater than 1, you can use transform without batching instead, got {len(index_dataloader)}."
240+
)
241+
237242
output = []
238243
for batch_idx, index_batch in enumerate(index_dataloader):
239244
# NOTE(celia): This is to prevent that adding the offset to the
@@ -449,6 +454,9 @@ def fit(
449454
if logdir is not None:
450455
self.save(logdir, f"checkpoint_{num_steps:#07d}.pth")
451456

457+
assert hasattr(self, "n_features")
458+
assert hasattr(self, "num_sessions")
459+
452460
def step(self, batch: cebra.data.Batch) -> dict:
453461
"""Perform a single gradient update.
454462
@@ -564,10 +572,8 @@ def _select_model(
564572
"""
565573
raise NotImplementedError
566574

567-
@property
568575
def _check_is_fitted(self):
569-
#NOTE(celia): instead of hasattr(model, "n_features_"), double check this!
570-
if not (hasattr(self, "history") and len(self.history) > 0):
576+
if not hasattr(self, "n_features"):
571577
raise ValueError(
572578
f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with "
573579
"appropriate arguments before using this estimator.")
@@ -598,15 +604,6 @@ def transform(self,
598604
Returns:
599605
The output embedding.
600606
"""
601-
if not self.is_fitted:
602-
raise ValueError(
603-
f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with "
604-
"appropriate arguments before using this estimator.")
605-
606-
if batch_size is not None and batch_size < 1:
607-
raise ValueError(
608-
f"Batch size should be at least 1, got {batch_size}")
609-
610607
if isinstance(inputs, list):
611608
raise ValueError(
612609
"Inputs to transform() should be the data for a single session, but received a list."
@@ -623,7 +620,7 @@ def transform(self,
623620
pad_before_transform = False
624621

625622
model.eval()
626-
if batch_size is not None:
623+
if batch_size is not None and inputs.shape[0] > int(batch_size * 2):
627624
output = _batched_transform(
628625
model=model,
629626
inputs=inputs,

cebra/solver/multi_session.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def _check_is_session_id_valid(self, session_id: Optional[int]):
177177
)
178178

179179
def _select_model(self, inputs: torch.Tensor, session_id: Optional[int]):
180-
""" Select the model based on the input dimension and session ID.
180+
""" Select the (trained) model based on the input dimension and session ID.
181181
182182
Args:
183183
inputs: Data to infer using the selected model.
@@ -189,6 +189,7 @@ def _select_model(self, inputs: torch.Tensor, session_id: Optional[int]):
189189
The model (first returns) and the offset of the model (second returns).
190190
"""
191191
self._check_is_session_id_valid(session_id=session_id)
192+
self._check_is_fitted()
192193
self._check_is_inputs_valid(inputs, session_id=session_id)
193194

194195
model = self.model[session_id]

cebra/solver/single_session.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _select_model(
103103
List[torch.Tensor]], session_id: Optional[int]
104104
) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module],
105105
cebra.data.datatypes.Offset]:
106-
""" Select the model based on the input dimension and session ID.
106+
""" Select the (trained) model based on the input dimension and session ID.
107107
108108
Args:
109109
inputs: Data to infer using the selected model.
@@ -114,8 +114,9 @@ def _select_model(
114114
Returns:
115115
The model (first returns) and the offset of the model (second returns).
116116
"""
117-
self._check_is_inputs_valid(inputs, session_id=session_id)
118117
self._check_is_session_id_valid(session_id=session_id)
118+
self._check_is_fitted()
119+
self._check_is_inputs_valid(inputs, session_id=session_id)
119120

120121
model = self.model
121122
offset = model.get_offset()
@@ -228,7 +229,7 @@ def _select_model(
228229
List[torch.Tensor]], session_id: Optional[int]
229230
) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module],
230231
cebra.data.datatypes.Offset]:
231-
""" Select the model based on the input dimension and session ID.
232+
""" Select the (trained) model based on the input dimension and session ID.
232233
233234
Args:
234235
inputs: Data to infer using the selected model.
@@ -239,8 +240,9 @@ def _select_model(
239240
Returns:
240241
The model (first returns) and the offset of the model (second returns).
241242
"""
242-
self._check_is_inputs_valid(inputs, session_id=session_id)
243243
self._check_is_session_id_valid(session_id=session_id)
244+
self._check_is_fitted()
245+
self._check_is_inputs_valid(inputs, session_id=session_id)
244246

245247
model = self.model.module
246248
if hasattr(model, 'get_offset'):

0 commit comments

Comments
 (0)