 )
 from lighteval.utils.imports import is_nanotron_available
 from lighteval.utils.parallelism import find_executable_batch_size
-from lighteval.utils.utils import EnvConfig, as_list
+from lighteval.utils.utils import as_list


 logger = logging.getLogger(__name__)
@@ -101,7 +101,6 @@ def __init__(
         trust_remote_code: bool = False,
         debug_one_layer_model: bool = False,
         model_class: Optional[Type] = None,
-        env_config: EnvConfig = None,
     ):
         """Initializes a nanotron model for evaluation.
         Args:
@@ -138,7 +137,6 @@ def __init__(
         self._add_special_tokens = add_special_tokens
         self._tokenizer = self._create_auto_tokenizer(
             pretrained=tokenizer.tokenizer_name_or_path,
-            env_config=env_config,
             trust_remote_code=trust_remote_code,
         )
         self._tokenizer.model_max_length = self.max_length
@@ -230,23 +228,18 @@ def _create_auto_tokenizer(
         *,
         pretrained: str,
         tokenizer: Optional[str] = None,
-        env_config: EnvConfig = None,
         trust_remote_code: bool = False,
     ) -> transformers.PreTrainedTokenizer:
         """Returns a pre-trained tokenizer from a pre-trained tokenizer configuration."""

         try:
             tokenizer = AutoTokenizer.from_pretrained(
                 pretrained if tokenizer is None else tokenizer,
-                cache_dir=env_config.cache_dir,
-                token=env_config.token,
                 trust_remote_code=trust_remote_code,
             )
         except RecursionError:
             tokenizer = AutoTokenizer.from_pretrained(
                 pretrained if tokenizer is None else tokenizer,
-                cache_dir=env_config.cache_dir,
-                token=env_config.token,
                 unk_token="<unk>",
                 trust_remote_code=trust_remote_code,
             )
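With `env_config` removed, the tokenizer load no longer passes an explicit cache directory or auth token, so `AutoTokenizer.from_pretrained` falls back to the standard Hub defaults (`HF_HOME`/`HF_HUB_CACHE` for the cache, `HF_TOKEN` or a saved `huggingface-cli login` token for auth). A minimal sketch of the resulting call, not part of the commit and using `"gpt2"` only as a placeholder model id:

```python
from transformers import AutoTokenizer

# cache_dir and token are no longer passed explicitly; huggingface_hub resolves
# them from HF_HOME / HF_TOKEN (or a previously saved login token).
tokenizer = AutoTokenizer.from_pretrained("gpt2", trust_remote_code=False)
```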
@@ -711,14 +704,14 @@ def _loglikelihood_single_token(
                 inputs, padding_length=max_context, max_context=max_context, full_attention_masks=True
             )
             # batched_inputs, batch_attention, input_lengths, truncated, padded
-
-            out = self.model(input_ids=batch_model.input_ids, input_mask=batch_model.input_mask)
+            position_ids = torch.arange(batch_model.input_ids.shape[1], device=self.device, dtype=torch.int32).unsqueeze(0).repeat(batch_model.input_ids.shape[0], 1)
+            out = self.model(input_ids=batch_model.input_ids, position_ids=position_ids)

             if dist.get_rank(self.parallel_context.pp_pg) == self.output_pp_rank:
                 # This process got outputs

-                # Gather all the output across TP
-                out = out.transpose(0, 1).contiguous()  # [batch, seq_length, vocab]
+                # Gather all the output across TP
+                out = out.view(*batch_model.input_ids.shape, -1).contiguous()  # [batch, seq_length, vocab]

                 gathered_out = [torch.zeros_like(out) for _ in range(self.parallel_context.tp_pg.size())]
                 dist.all_gather(gathered_out, out, group=self.parallel_context.tp_pg, async_op=False)
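Both `_loglikelihood_single_token` (above) and `_loglikelihood_tokens` (below) switch from passing `input_mask` to passing explicitly materialized `position_ids`, and reshape the model output with `.view(*batch_model.input_ids.shape, -1)` instead of `transpose(0, 1)`. A minimal, self-contained sketch of what those new lines compute, with made-up sizes and a random tensor standing in for the model output (not part of the commit):

```python
import torch

batch_size, seq_length, vocab_size = 2, 5, 11
input_ids = torch.zeros(batch_size, seq_length, dtype=torch.long)

# One row of positions 0..seq_length-1 per sample, as built in the added lines.
position_ids = (
    torch.arange(seq_length, dtype=torch.int32)
    .unsqueeze(0)            # [1, seq_length]
    .repeat(batch_size, 1)   # [batch_size, seq_length]
)
assert position_ids.shape == (batch_size, seq_length)

# The output is reshaped with .view(*input_ids.shape, -1) rather than transposed,
# i.e. it is assumed to come back flattened over the batch and sequence dimensions.
flat_logits = torch.randn(batch_size * seq_length, vocab_size)
out = flat_logits.view(*input_ids.shape, -1)  # [batch, seq_length, vocab]
assert out.shape == (batch_size, seq_length, vocab_size)
```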
@@ -944,7 +937,8 @@ def _loglikelihood_tokens(
             )
             # batched_inputs, batch_attention, input_lengths, truncated, padded
             with torch.no_grad():
-                out = self.model(input_ids=batch_model.input_ids, input_mask=batch_model.input_mask)
+                position_ids = torch.arange(batch_model.input_ids.shape[1], device=self.device, dtype=torch.int32).unsqueeze(0).repeat(batch_model.input_ids.shape[0], 1)
+                out = self.model(input_ids=batch_model.input_ids, position_ids=position_ids)

             if dist.get_rank(self.parallel_context.pp_pg) == self.output_pp_rank:
                 # This process got outputs
@@ -954,7 +948,7 @@ def _loglikelihood_tokens(
                 dist.all_gather(gathered_out, out, group=self.parallel_context.tp_pg, async_op=False)
                 out = torch.cat(gathered_out, dim=-1)

-                out = out.transpose(0, 1)  # [batch, seq_length, vocab]
+                out = out.view(*batch_model.input_ids.shape, -1)  # [batch, seq_length, vocab]

                 multi_logits = F.log_softmax(out, dim=-1)  # [batch, padding_length, vocab]

                 logits_sum = []
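After the tensor-parallel gather and the reshape above, `multi_logits` holds per-position log-probabilities from which per-token log-likelihoods are accumulated (the excerpt stops at `logits_sum = []`). As a rough illustration of that scoring pattern only, not lighteval's exact continuation handling, a hedged sketch with toy shapes and hypothetical reference token ids:

```python
import torch
import torch.nn.functional as F

# Toy stand-ins: `out` plays the role of the gathered [batch, padding_length, vocab] logits.
batch, padding_length, vocab = 2, 4, 10
out = torch.randn(batch, padding_length, vocab)
cont_toks = torch.randint(0, vocab, (batch, padding_length))  # hypothetical reference token ids

multi_logits = F.log_softmax(out, dim=-1)                      # [batch, padding_length, vocab]
# Log-probability of each reference token, summed per sample.
token_logprobs = torch.gather(multi_logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
loglikelihoods = token_logprobs.sum(dim=-1)                    # [batch]
```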