@@ -284,7 +284,7 @@ def fit(
284284 if not end_to_end :
285285 self .freeze ("body" )
286286
287- dataloader = self ._prepare_dataloader (x_train , y_train , batch_size , max_length )
287+ dataloader = self ._prepare_dataloader (list ( x_train ), list ( y_train ) , batch_size , max_length )
288288 criterion = self .model_head .get_loss_fn ()
289289 optimizer = self ._prepare_optimizer (head_learning_rate , body_learning_rate , l2_weight )
290290 scheduler = torch .optim .lr_scheduler .StepLR (optimizer , step_size = 5 , gamma = 0.5 )
@@ -314,8 +314,8 @@ def fit(
314314 if not end_to_end :
315315 self .unfreeze ("body" )
316316 else : # train with sklearn
317- embeddings = self .model_body .encode (x_train , normalize_embeddings = self .normalize_embeddings )
318- self .model_head .fit (embeddings , y_train )
317+ embeddings = self .model_body .encode (list ( x_train ) , normalize_embeddings = self .normalize_embeddings )
318+ self .model_head .fit (embeddings , list ( y_train ) )
319319 if self .labels is None and self .multi_target_strategy is None :
320320 # Try to set the labels based on the head classes, if they exist
321321 # This can fail in various ways, so we catch all exceptions
@@ -477,6 +477,7 @@ def _output_type_conversion(
477477 outputs = torch .from_numpy (outputs )
478478 return outputs
479479
480+ @torch .no_grad ()
480481 def predict_proba (
481482 self ,
482483 inputs : Union [str , List [str ]],
@@ -521,6 +522,7 @@ def predict_proba(
521522 outputs = self ._output_type_conversion (probs , as_numpy = as_numpy )
522523 return outputs [0 ] if is_singular else outputs
523524
525+ @torch .no_grad ()
524526 def predict (
525527 self ,
526528 inputs : Union [str , List [str ]],
@@ -556,7 +558,7 @@ def predict(
556558 is_singular = isinstance (inputs , str )
557559 if is_singular :
558560 inputs = [inputs ]
559- embeddings = self .encode (inputs , batch_size = batch_size , show_progress_bar = show_progress_bar )
561+ embeddings = self .encode (list ( inputs ) , batch_size = batch_size , show_progress_bar = show_progress_bar )
560562 preds = self .model_head .predict (embeddings )
561563 # If labels are defined, we don't have multilabels & the output is not already strings, then we convert to string labels
562564 if (
0 commit comments