@@ -226,6 +226,12 @@ def __init__(
             model_size=str(model_size),
         )
 
+    def cleanup(self):
+        """Clean up operations if needed, such as closing an endpoint."""
+        del self.model
+        del self._tokenizer
+        torch.cuda.empty_cache()
+
     @classmethod
     def from_model(
         cls,
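
Not part of the patch: a minimal usage sketch of the new cleanup() hook. The class name, config, and requests below are assumptions for illustration, not taken from this diff.

    # Hypothetical driver code; TransformersModel, config, and requests are assumed.
    model = TransformersModel(config)
    try:
        responses = model.greedy_until(requests)
    finally:
        model.cleanup()  # drop model/tokenizer refs and release cached GPU memory
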
@@ -543,7 +549,7 @@ def greedy_until(
                 longest_context_continuation_size_in_split, self.max_length
             )
             batch_size = self._get_batch_size(
-                override_bs=self.batch_size,
+                override_bs=self.config.batch_size,
                 max_input_length=max_context_continuation_size_allowed,
                 starting_batch_size=starting_batch_size,
             )
@@ -710,7 +716,6 @@ def _generate(
     def loglikelihood(
         self,
         requests: list[LoglikelihoodRequest],
-        override_bs: Optional[int] = None,
     ) -> list[LoglikelihoodResponse]:
         """Tokenize the context and continuation and compute the log likelihood of those
         tokenized sequences.
@@ -731,12 +736,11 @@ def loglikelihood(
                     request.context, request.choice, pairwise=self.pairwise_tokenization
                 )
 
-        return self._loglikelihood_tokens(requests, override_bs=override_bs)
+        return self._loglikelihood_tokens(requests)
 
     def loglikelihood_rolling(
         self,
         requests: list[LoglikelihoodRollingRequest],
-        override_bs=None,
     ) -> list[LoglikelihoodResponse]:
         """This function is used to compute the log likelihood of the context for perplexity metrics."""
 
@@ -746,7 +750,6 @@ def loglikelihood_rolling(
 
         results = self._loglikelihood_tokens(
             requests,
-            override_bs=override_bs,
             return_bool_score=False,
             rolling=True,
         )
@@ -755,7 +758,6 @@ def loglikelihood_rolling(
     def _loglikelihood_tokens(
         self,
         requests: list[LoglikelihoodRequest],
-        override_bs: int = -1,
         return_bool_score: bool = True,
         rolling: bool = False,
     ) -> list[LoglikelihoodResponse]:
@@ -774,7 +776,7 @@ def _loglikelihood_tokens(
             )
 
             batch_size = self._get_batch_size(
-                override_bs=override_bs,
+                override_bs=self.config.batch_size,
                 max_input_length=max_context_continuation_size_allowed,
                 starting_batch_size=starting_batch_size,
             )
@@ -967,7 +969,8 @@ def pad_and_gather(
         return output_tensor, length_tensor
 
     def loglikelihood_single_token(
-        self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
+        self,
+        requests: list[LoglikelihoodSingleTokenRequest],
     ) -> list[LoglikelihoodSingleTokenResponse]:
         """Tokenize the context and continuation and compute the log likelihood of those
         tokenized sequences.
@@ -996,10 +999,11 @@ def loglikelihood_single_token(
                 )
             request.tokenized_continuation = continuations_enc
 
-        return self._loglikelihood_single_token(requests, override_bs=override_bs)
+        return self._loglikelihood_single_token(requests)
 
     def _loglikelihood_single_token(
-        self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: int = -1
+        self,
+        requests: list[LoglikelihoodSingleTokenRequest],
     ) -> list[LoglikelihoodSingleTokenResponse]:
         dataset = LoglikelihoodSingleTokenDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         starting_batch_size = STARTING_BATCH_SIZE
@@ -1008,7 +1012,7 @@ def _loglikelihood_single_token(
         for split_start, split_end in tqdm(dataset.splits_start_end_iterator()):
             context_enc = dataset[0].tokenized_context
             max_context = len(context_enc[-self.max_length :])
-            batch_size = self._get_batch_size(override_bs=self.config.batch_size, max_input_length=max_context)
+            batch_size = self._get_batch_size(override_bs=self.config.batch_size, max_input_length=max_context)
             starting_batch_size = batch_size * 2
 
             dataloader = DataLoader(dataset, batch_size=starting_batch_size, collate_fn=lambda batch: batch)
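
Taken together, these hunks drop the per-call override_bs parameter: the batch size is now read once from the model-level config and forwarded to _get_batch_size as override_bs=self.config.batch_size. Not part of the patch: a before/after sketch of the caller-facing change, using a hypothetical batch size of 8.

    # Before this patch: batch size threaded through each request method.
    model.loglikelihood(requests, override_bs=8)

    # After this patch: batch size set once on the model config.
    model.config.batch_size = 8
    model.loglikelihood(requests)
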