Skip to content

Commit 5e7a14c

Browse files
gonlairostes
authored and committed
remove float16
1 parent bc8ee25 commit 5e7a14c

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

cebra/integrations/sklearn/cebra.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,7 +1235,7 @@ def transform(self,
12351235
# Input validation
12361236
#TODO: if inputs are in cuda, then it throws an error, deal with this.
12371237
X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
1238-
input_dtype = X.dtype
1238+
#input_dtype = X.dtype
12391239

12401240
if isinstance(X, np.ndarray):
12411241
X = torch.from_numpy(X)
@@ -1248,10 +1248,11 @@ def transform(self,
12481248
session_id=session_id,
12491249
batch_size=batch_size)
12501250

1251-
if input_dtype == "float64":
1252-
return output.astype(input_dtype)
1251+
#TODO: check if this is safe.
1252+
return output.numpy(force=True)
12531253

1254-
return output
1254+
#if input_dtype == "float64":
1255+
# return output.astype(input_dtype)
12551256

12561257
def fit_transform(
12571258
self,

cebra/integrations/sklearn/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray:
7878
X,
7979
accept_sparse=False,
8080
accept_large_sparse=False,
81-
dtype=("float16", "float32", "float64"),
81+
# NOTE: remove float16 because F.pad does not allow float16.
82+
dtype=("float32", "float64"),
8283
order=None,
8384
copy=False,
8485
force_all_finite=True,

0 commit comments

Comments
 (0)