Skip to content

Commit b73c123

Browse files
gonlairoCeliaBenquet
authored andcommitted
remove float16
1 parent ad56472 commit b73c123

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

cebra/integrations/sklearn/cebra.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1287,7 +1287,7 @@ def transform_deprecated(self,
12871287

12881288
# Input validation
12891289
X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
1290-
input_dtype = X.dtype
1290+
#input_dtype = X.dtype
12911291

12921292
with torch.no_grad():
12931293
model.eval()
@@ -1305,10 +1305,11 @@ def transform_deprecated(self,
13051305
# Standard evaluation, (T, C, dt)
13061306
output = model(X).cpu().numpy()
13071307

1308-
if input_dtype == "float64":
1309-
return output.astype(input_dtype)
1308+
#TODO: check if this is safe.
1309+
return output.numpy(force=True)
13101310

1311-
return output
1311+
#if input_dtype == "float64":
1312+
# return output.astype(input_dtype)
13121313

13131314
def fit_transform(
13141315
self,

0 commit comments

Comments
 (0)