@@ -226,6 +226,12 @@ def __init__(
226
226
model_size = str (model_size ),
227
227
)
228
228
229
def cleanup(self):
    """Release model resources held by this instance.

    Drops the references to the loaded model and tokenizer so they can be
    garbage-collected, then returns any cached CUDA memory to the driver.
    Safe to call more than once and safe on a partially-initialized
    instance (e.g. when ``__init__`` failed before the model was set).
    """
    # Guard the deletions: a bare `del self.model` raises AttributeError
    # on a second call or when construction never reached the assignment.
    if hasattr(self, "model"):
        del self.model
    if hasattr(self, "_tokenizer"):
        del self._tokenizer
    # empty_cache() is a no-op when CUDA was never initialized, so this
    # is safe on CPU-only hosts as well.
    torch.cuda.empty_cache()
229
235
@classmethod
230
236
def from_model (
231
237
cls ,
@@ -543,7 +549,7 @@ def greedy_until(
543
549
longest_context_continuation_size_in_split , self .max_length
544
550
)
545
551
batch_size = self ._get_batch_size (
546
- override_bs = self .batch_size ,
552
+ override_bs = self .config . batch_size ,
547
553
max_input_length = max_context_continuation_size_allowed ,
548
554
starting_batch_size = starting_batch_size ,
549
555
)
@@ -710,7 +716,6 @@ def _generate(
710
716
def loglikelihood (
711
717
self ,
712
718
requests : list [LoglikelihoodRequest ],
713
- override_bs : Optional [int ] = None ,
714
719
) -> list [LoglikelihoodResponse ]:
715
720
"""Tokenize the context and continuation and compute the log likelihood of those
716
721
tokenized sequences.
@@ -731,12 +736,11 @@ def loglikelihood(
731
736
request .context , request .choice , pairwise = self .pairwise_tokenization
732
737
)
733
738
734
- return self ._loglikelihood_tokens (requests , override_bs = override_bs )
739
+ return self ._loglikelihood_tokens (requests )
735
740
736
741
def loglikelihood_rolling (
737
742
self ,
738
743
requests : list [LoglikelihoodRollingRequest ],
739
- override_bs = None ,
740
744
) -> list [LoglikelihoodResponse ]:
741
745
"""This function is used to compute the log likelihood of the context for perplexity metrics."""
742
746
@@ -746,7 +750,6 @@ def loglikelihood_rolling(
746
750
747
751
results = self ._loglikelihood_tokens (
748
752
requests ,
749
- override_bs = override_bs ,
750
753
return_bool_score = False ,
751
754
rolling = True ,
752
755
)
@@ -755,7 +758,6 @@ def loglikelihood_rolling(
755
758
def _loglikelihood_tokens (
756
759
self ,
757
760
requests : list [LoglikelihoodRequest ],
758
- override_bs : int = - 1 ,
759
761
return_bool_score : bool = True ,
760
762
rolling : bool = False ,
761
763
) -> list [LoglikelihoodResponse ]:
@@ -774,7 +776,7 @@ def _loglikelihood_tokens(
774
776
)
775
777
776
778
batch_size = self ._get_batch_size (
777
- override_bs = override_bs ,
779
+ override_bs = self . config . batch_size ,
778
780
max_input_length = max_context_continuation_size_allowed ,
779
781
starting_batch_size = starting_batch_size ,
780
782
)
@@ -967,7 +969,8 @@ def pad_and_gather(
967
969
return output_tensor , length_tensor
968
970
969
971
def loglikelihood_single_token (
970
- self , requests : list [LoglikelihoodSingleTokenRequest ], override_bs : Optional [int ] = None
972
+ self ,
973
+ requests : list [LoglikelihoodSingleTokenRequest ],
971
974
) -> list [LoglikelihoodSingleTokenResponse ]:
972
975
"""Tokenize the context and continuation and compute the log likelihood of those
973
976
tokenized sequences.
@@ -996,10 +999,11 @@ def loglikelihood_single_token(
996
999
)
997
1000
request .tokenized_continuation = continuations_enc
998
1001
999
- return self ._loglikelihood_single_token (requests , override_bs = override_bs )
1002
+ return self ._loglikelihood_single_token (requests )
1000
1003
1001
1004
def _loglikelihood_single_token (
1002
- self , requests : list [LoglikelihoodSingleTokenRequest ], override_bs : int = - 1
1005
+ self ,
1006
+ requests : list [LoglikelihoodSingleTokenRequest ],
1003
1007
) -> list [LoglikelihoodSingleTokenResponse ]:
1004
1008
dataset = LoglikelihoodSingleTokenDataset (requests = requests , num_dataset_splits = self .DATASET_SPLITS )
1005
1009
starting_batch_size = STARTING_BATCH_SIZE
@@ -1008,7 +1012,7 @@ def _loglikelihood_single_token(
1008
1012
for split_start , split_end in tqdm (dataset .splits_start_end_iterator ()):
1009
1013
context_enc = dataset [0 ].tokenized_context
1010
1014
max_context = len (context_enc [- self .max_length :])
1011
- batch_size = self ._get_batch_size (override_bs = override_bs , max_input_length = max_context )
1015
+ batch_size = self ._get_batch_size (override_bs = self . config . batch_size , max_input_length = max_context )
1012
1016
starting_batch_size = batch_size * 2
1013
1017
1014
1018
dataloader = DataLoader (dataset , batch_size = starting_batch_size , collate_fn = lambda batch : batch )
0 commit comments