We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8c8be85 commit 59df402Copy full SHA for 59df402
cebra/integrations/sklearn/cebra.py
@@ -1233,10 +1233,13 @@ def transform(self,
1233
1234
sklearn_utils_validation.check_is_fitted(self, "n_features_")
1235
# Input validation
1236
+ #TODO: if inputs are in cuda, then it throws an error, deal with this.
1237
X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
1238
input_dtype = X.dtype
- #print(type(X))
1239
- #print(X.dtype)
+
1240
+ if isinstance(X, np.ndarray):
1241
+ X = torch.from_numpy(X)
1242
+ # TODO: which type and device should be put there?
1243
1244
with torch.no_grad():
1245
output = self.solver_.transform(
0 commit comments