
Commit b417a23

CeliaBenquet authored and stes committed
Improve modularity, remove duplicate code and TODOs
1 parent 9fe3af3 commit b417a23

File tree

5 files changed: +359 −178 lines changed


cebra/integrations/sklearn/cebra.py

Lines changed: 7 additions & 37 deletions
@@ -791,33 +791,7 @@ def _configure_for_all(

     def _select_model(self, X: Union[npt.NDArray, torch.Tensor],
                       session_id: int):
-        # Choose the model and get its corresponding offset
-        if self.num_sessions is not None:  # multisession implementation
-            if session_id is None:
-                raise RuntimeError(
-                    "No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape."
-                )
-            if session_id >= self.num_sessions or session_id < 0:
-                raise RuntimeError(
-                    f"Invalid session_id {session_id}: session_id for the current multisession model must be between 0 and {self.num_sessions-1}."
-                )
-            if self.n_features_[session_id] != X.shape[1]:
-                raise ValueError(
-                    f"Invalid input shape: model for session {session_id} requires an input of shape"
-                    f"(n_samples, {self.n_features_[session_id]}), got (n_samples, {X.shape[1]})."
-                )
-
-            model = self.model_[session_id]
-            model.to(self.device_)
-        else:  # single session
-            if session_id is not None and session_id > 0:
-                raise RuntimeError(
-                    f"Invalid session_id {session_id}: single session models only takes an optional null session_id."
-                )
-            model = self.model_
-
-        offset = model.get_offset()
-        return model, offset
+        return self.solver_._select_model(X, session_id=session_id)

     def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):
         """Check that the input labels are compatible with the labels used to fit the model.
@@ -1224,16 +1198,16 @@ def transform(self,
         >>> embedding = cebra_model.transform(dataset)

         """
-
+        self.solver_._check_is_session_id_valid(session_id=session_id)
         sklearn_utils_validation.check_is_fitted(self, "n_features_")
-        # Input validation
-        #TODO: if inputs are in cuda, then it throws an error, deal with this.
+
+        if torch.is_tensor(X) and X.device.type == "cuda":
+            X = X.detach().cpu()
+
         X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
-        #input_dtype = X.dtype

         if isinstance(X, np.ndarray):
             X = torch.from_numpy(X)
-        # TODO: which type and device should be put there?

         with torch.no_grad():
             output = self.solver_.transform(
@@ -1242,11 +1216,7 @@ def transform(self,
                 session_id=session_id,
                 batch_size=batch_size)

-        #TODO: check if this is safe.
-        return output.numpy(force=True)
-
-        #if input_dtype == "float64":
-        # return output.astype(input_dtype)
+        return output.detach().cpu().numpy()

     def fit_transform(
         self,
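Note: taken together, these two hunks mean transform now detaches CUDA inputs instead of failing during input validation and always returns a CPU NumPy array. A smoke-test sketch under the assumption that a CUDA device is available (model settings and data shapes are placeholder values, not from this commit):

import numpy as np
import torch
import cebra

X = np.random.uniform(size=(1000, 30)).astype("float32")
model = cebra.CEBRA(max_iterations=10, batch_size=512)
model.fit(X)

X_gpu = torch.from_numpy(X).to("cuda")    # previously raised inside input validation
embedding = model.transform(X_gpu)        # tensor is detached and moved to CPU internally
assert isinstance(embedding, np.ndarray)  # output comes back as a NumPy array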

cebra/integrations/sklearn/metrics.py

Lines changed: 2 additions & 1 deletion
@@ -83,7 +83,8 @@ def infonce_loss(
                         f"got {len(y[0])} sessions.")

    model, _ = cebra_model._select_model(
-        X, session_id)  # check session_id validity and corresponding model
+        X, session_id=session_id
+    )  # check session_id validity and corresponding model
    cebra_model._check_labels_types(y, session_id=session_id)

    dataset, is_multisession = cebra_model._prepare_data(X, y)  # single session
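Note: the only change here is that session_id is now passed to _select_model as an explicit keyword argument; infonce_loss itself behaves as before. A usage sketch with hypothetical data shapes and placeholder parameter values (single-session model, so session_id=0 is the only valid non-null value):

import numpy as np
import cebra
from cebra.integrations.sklearn.metrics import infonce_loss

X = np.random.uniform(size=(1000, 30)).astype("float32")
y = np.random.uniform(size=(1000, 2)).astype("float32")
model = cebra.CEBRA(max_iterations=10, batch_size=512).fit(X, y)
score = infonce_loss(model, X, y, session_id=0, num_batches=5)  # mean InfoNCE over sampled batches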
