We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9c46eb9 commit c845ec3Copy full SHA for c845ec3
cebra/integrations/sklearn/cebra.py
@@ -1196,8 +1196,8 @@ def transform(self,
1196
>>> embedding = cebra_model.transform(dataset)
1197
1198
"""
1199
- self.solver_._check_is_session_id_valid(session_id=session_id)
1200
sklearn_utils_validation.check_is_fitted(self, "n_features_")
+ self.solver_._check_is_session_id_valid(session_id=session_id)
1201
1202
if torch.is_tensor(X) and X.device.type == "cuda":
1203
X = X.detach().cpu()
0 commit comments