 from tqdm import tqdm
 from transformers import AutoTokenizer, BatchEncoding
 
+from lighteval.config.lighteval_config import FullNanotronConfig
 from lighteval.data import (
     GenDistributedSampler,
     GenerativeTaskDatasetNanotron,
 )
 from lighteval.utils.imports import is_nanotron_available
 from lighteval.utils.parallelism import find_executable_batch_size
-from lighteval.utils.utils import EnvConfig, as_list, boolstring_to_bool
+from lighteval.utils.utils import EnvConfig, as_list
 
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]
 
 if is_nanotron_available():
-    import nanotron
     from nanotron import distributed as dist
     from nanotron import logging
-    from nanotron.config import LightEvalConfig, ModelArgs, TokenizerArgs
     from nanotron.generation.decode import decode_tokenized
     from nanotron.logging import human_format, log_rank
     from nanotron.models import build_model
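The new `FullNanotronConfig` import replaces the raw `nanotron.config.Config` annotation in the constructor below. Judging from the attribute accesses introduced in this diff (`nanotron_config.nanotron_config.model`, `nanotron_config.lighteval_config.parallelism`), the wrapper simply pairs the nanotron training config with the lighteval one. A minimal sketch of that shape, inferred from the diff rather than copied from lighteval/config/lighteval_config.py:

    # Sketch only: field types are assumptions; the real class lives in
    # lighteval.config.lighteval_config.
    from dataclasses import dataclass

    @dataclass
    class FullNanotronConfig:
        lighteval_config: "LightEvalConfig"  # lighteval-side settings (tasks, parallelism, generation)
        nanotron_config: "Config"            # nanotron training config (model, tokenizer, general)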
@@ -90,7 +89,7 @@ class NanotronLightevalModel(LightevalModel):
     def __init__(
         self,
         checkpoint_path: str,
-        nanotron_config: nanotron.config.Config,
+        nanotron_config: FullNanotronConfig,
         parallel_context: ParallelContext,
         max_gen_toks: Optional[int] = 256,
         max_length: Optional[int] = None,
@@ -104,12 +103,11 @@ def __init__(
104103 """Initializes a nanotron model for evaluation.
105104 Args:
106105 """
107- model_args : ModelArgs = nanotron_config .model
108- tokenizer : TokenizerArgs = nanotron_config .tokenizer
109- lighteval_config : LightEvalConfig = nanotron_config .lighteval
110- parallel_config : ParallelContext = nanotron_config .lighteval .parallelism
106+ model_args = nanotron_config . nanotron_config .model
107+ tokenizer = nanotron_config . nanotron_config .tokenizer
108+ lighteval_config = nanotron_config .lighteval_config
109+ parallel_config = nanotron_config .lighteval_config .parallelism
111110
112- self ._batch_size = lighteval_config .batch_size
113111 self ._max_gen_toks = max_gen_toks
114112 self ._max_length = max_length
115113 self .parallel_config = parallel_config
@@ -120,9 +118,7 @@ def __init__(
             raise ValueError("PP parallelism is not supported yet")
 
         # multichoice_continuations_start_space can be True (forcing space), False (forcing no space) or None (no forcing)
-        multichoice_continuations_start_space = boolstring_to_bool(
-            lighteval_config.tasks.multichoice_continuations_start_space
-        )
+        multichoice_continuations_start_space = lighteval_config.tasks.multichoice_continuations_start_space
 
         self.generation_config = lighteval_config.generation
         if isinstance(self.generation_config, dict):
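Dropping `boolstring_to_bool` means the config field is now expected to arrive as a genuine `Optional[bool]` rather than a string to parse. The tri-state behaviour the comment describes (True forces a leading space, False forces none, None leaves continuations untouched) is presumably what `_check_continuations_start_space` in a later hunk implements; a hypothetical standalone version, for illustration only:

    # Hypothetical helper mirroring the comment's tri-state semantics; not part of the diff.
    from typing import Optional

    def apply_start_space(continuation: str, start_space: Optional[bool]) -> str:
        if start_space is True and not continuation.startswith(" "):
            return " " + continuation        # force a leading space
        if start_space is False and continuation.startswith(" "):
            return continuation.lstrip(" ")  # force no leading space
        return continuation                  # None: no forcing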
@@ -217,7 +213,9 @@ def __init__(
 
         self.multichoice_continuations_start_space = multichoice_continuations_start_space
 
-        self.model_info = ModelInfo(model_name=f"{nanotron_config.general.run}/{nanotron_config.general.step}")
+        self.model_info = ModelInfo(
+            model_name=f"{nanotron_config.nanotron_config.general.run}/{nanotron_config.nanotron_config.general.step}"
+        )
 
     @property
     def tokenizer(self):
@@ -299,12 +297,6 @@ def max_length(self) -> int:
             return self.tokenizer.model_max_length
         return self._DEFAULT_MAX_LENGTH
 
-    @property
-    def batch_size(self) -> int:
-        if self._batch_size >= 0:
-            self._batch_size = self._get_batch_size(max_input_length=self.max_length)
-        return self._batch_size  # * gpus
-
     @property
     def device(self) -> Union[int, str, torch.device]:
         return "cuda"
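With the cached `batch_size` property gone, every code path now sizes batches on demand through `_get_batch_size`, which sits on top of the `find_executable_batch_size` import near the top of the file. Assuming it follows the usual accelerate-style convention, the idea is to start large and halve on CUDA OOM; a rough sketch of that pattern (not lighteval's actual implementation):

    # Illustrative only; the real helper is lighteval.utils.parallelism.find_executable_batch_size.
    import torch

    def find_executable_batch_size_sketch(fn, starting_batch_size: int = 512):
        batch_size = starting_batch_size
        while batch_size > 0:
            try:
                return fn(batch_size)      # run the workload at the current size
            except torch.cuda.OutOfMemoryError:
                torch.cuda.empty_cache()
                batch_size //= 2           # halve and retry after an OOM
        raise RuntimeError("no executable batch size found")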
@@ -415,7 +407,7 @@ def _check_continuations_start_space(self, continuation: str) -> str:
         return continuation
 
     def loglikelihood_single_token(
-        self, requests: List[Tuple[str, dict]], override_bs=None
+        self, requests: List[Tuple[str, dict]], override_bs=0
     ) -> List[LoglikelihoodSingleTokenResponse]:
         """Tokenize the context and continuation and compute the log likelihood of those
         tokenized sequences.
@@ -475,7 +467,7 @@ def loglikelihood(self, requests: List[LoglikelihoodRequest], override_bs=None)
         )
 
     def loglikelihood_rolling(
-        self, requests: List[LoglikelihoodRollingRequest], override_bs=None
+        self, requests: List[LoglikelihoodRollingRequest], override_bs: int = 0
     ) -> List[LoglikelihoodResponse]:
         """This function is used to compute the log likelihood of the context for perplexity metrics."""
         for request in tqdm(
@@ -652,7 +644,7 @@ def _get_subsets(self, dataset, num_dataset_splits):
 
     @torch.inference_mode()
     def _loglikelihood_single_token(
-        self, requests, disable_tqdm: bool = False, override_bs: int = -1, num_dataset_splits: int = 1
+        self, requests, disable_tqdm: bool = False, override_bs: int = 0, num_dataset_splits: int = 1
     ) -> List[LoglikelihoodSingleTokenResponse]:
         dataset = LoglikelihoodSingleTokenDataset(requests=requests)
         res = []
@@ -1115,7 +1107,7 @@ def greedy_until(
         self,
         requests: List[GreedyUntilRequest],
         disable_tqdm: bool = False,
-        override_bs=None,
+        override_bs: int = -1,
         num_dataset_splits: int = 1,
     ) -> List[GenerativeResponse]:
         """Greedy generation until a stop token is generated."""
@@ -1155,7 +1147,7 @@ def greedy_until(
             max_input_length = min(len(context_enc) + max_gen, self.max_length)
 
             batch_size = self._get_batch_size(
-                override_bs=self._batch_size,
+                override_bs=override_bs,
                 max_input_length=max_input_length,
                 starting_batch_size=starting_batch_size,
             )
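This last hunk closes the loop on the `_batch_size` removal: `greedy_until` now honours its own `override_bs` argument instead of the deleted instance attribute, with the new defaults above (-1 here, 0 on the loglikelihood paths) apparently serving as "no explicit override" sentinels. A hypothetical call under that reading:

    # Hypothetical usage; sentinel semantics are inferred from the new defaults, not documented.
    responses = model.greedy_until(requests, override_bs=-1)  # let _get_batch_size choose a size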