File tree Expand file tree Collapse file tree 1 file changed +5
-4
lines changed
cebra/integrations/sklearn Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -1287,7 +1287,7 @@ def transform_deprecated(self,
12871287
12881288 # Input validation
12891289 X = sklearn_utils .check_input_array (X , min_samples = len (self .offset_ ))
1290- input_dtype = X .dtype
1290+ # input_dtype = X.dtype
12911291
12921292 with torch .no_grad ():
12931293 model .eval ()
@@ -1305,10 +1305,11 @@ def transform_deprecated(self,
13051305 # Standard evaluation, (T, C, dt)
13061306 output = model (X ).cpu ().numpy ()
13071307
1308- if input_dtype == "float64" :
1309- return output .astype ( input_dtype )
1308+ #TODO: check if this is safe.
1309+ return output .numpy ( force = True )
13101310
1311- return output
1311+ #if input_dtype == "float64":
1312+ # return output.astype(input_dtype)
13121313
13131314 def fit_transform (
13141315 self ,
You can’t perform that action at this time.
0 commit comments