Skip to content

Commit 59df402

Browse files
gonlairostes
authored andcommitted
convert to torch if numpy array as inputs
1 parent 8c8be85 commit 59df402

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

cebra/integrations/sklearn/cebra.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,10 +1233,13 @@ def transform(self,
12331233

12341234
sklearn_utils_validation.check_is_fitted(self, "n_features_")
12351235
# Input validation
1236+
#TODO: if inputs are in cuda, then it throws an error, deal with this.
12361237
X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
12371238
input_dtype = X.dtype
1238-
#print(type(X))
1239-
#print(X.dtype)
1239+
1240+
if isinstance(X, np.ndarray):
1241+
X = torch.from_numpy(X)
1242+
# TODO: which type and device should be put there?
12401243

12411244
with torch.no_grad():
12421245
output = self.solver_.transform(

0 commit comments

Comments
 (0)