@@ -109,14 +109,23 @@ def __init__(
         self.async_client = AsyncInferenceClient(model=config.model, token=env_config.token)
         self.client = InferenceClient(model=config.model, token=env_config.token)

-        self.use_async = False  # for debug - async use is faster
+        self.use_async = True  # set to False for debugging - async use is faster

         self._tokenizer = AutoTokenizer.from_pretrained(self.name)
+        self._add_special_tokens = config.add_special_tokens if config.add_special_tokens is not None else False

     @property
     def tokenizer(self):
         return self._tokenizer

+    @property
+    def add_special_tokens(self):
+        return self._add_special_tokens
+
+    @property
+    def disable_tqdm(self) -> bool:
+        return False  # no accelerator, so this is always the main process
+
     def cleanup(self):
         if self.endpoint is not None:
             self.endpoint.delete()
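For orientation, here is a minimal sketch of how the new add_special_tokens property could feed the tok_encode call used in greedy_until below. The helper's body is an assumption; only the property itself appears in this commit.

# Hedged sketch, assuming tok_encode wraps the HF tokenizer's encode():
def tok_encode(self, text: str) -> list[int]:
    # Honor the configured default instead of the tokenizer's implicit one.
    return self._tokenizer.encode(text, add_special_tokens=self.add_special_tokens)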
@@ -250,7 +259,8 @@ def greedy_until(
         override_bs: Optional[int] = None,
     ) -> List[GenerateReturn]:
         for request in requests:
-            request.stop_sequence = request.stop_sequence + [self.tokenizer.eos_token]
+            request.tokenized_context = self.tok_encode(request.context)
+            request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]

         dataset = GenerativeTaskDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
         batch_size = override_bs if override_bs is not None else BATCH_SIZE
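The new as_list call guards against tasks that store a single stop string rather than a list. Its exact implementation is not part of this diff; the sketch below shows the behavior the line above assumes:

# Assumed behavior of the as_list utility: normalize a bare value (or None)
# into a list so the `+ [eos_token]` concatenation is always valid.
def as_list(item):
    if item is None:
        return []
    return list(item) if isinstance(item, (list, tuple)) else [item]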
@@ -268,10 +278,11 @@ def greedy_until(
         for batch in tqdm(
             dataloader, desc="Greedy generation", position=1, leave=False, disable=self.disable_tqdm
         ):
+            # the `returns_logits` flag is only used to filter the results; we always request the full details.
             if self.use_async:
-                responses = asyncio.run(self.__async_process_batch_generate(batch, returns_logits))
+                responses = asyncio.run(self.__async_process_batch_generate(batch))
             else:
-                responses = self.__process_batch_generate(batch)
+                responses = self.__process_batch_generate(batch)
             for response in responses:
                 results.append(
                     GenerateReturn(
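With returns_logits no longer forwarded, the batch helpers can unconditionally ask the endpoint for full details. A rough sketch of what the async helper might look like, assuming huggingface_hub's text_generation API; the body of this private method is a guess, not part of this commit:

# Rough sketch of the assumed async helper: one text_generation call per
# request, gathered concurrently in batch order.
async def __async_process_batch_generate(self, batch):
    return await asyncio.gather(
        *(
            self.async_client.text_generation(
                request.context,
                details=True,  # always fetch full details; `returns_logits` only filters later
                stop_sequences=request.stop_sequence,
            )
            for request in batch
        )
    )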
@@ -282,7 +293,7 @@ def greedy_until(
                     )
                 )

-        return results
+        return dataset.get_original_order(results)

     def loglikelihood(
         self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
@@ -321,7 +332,7 @@ def loglikelihood(
                     )
                 )

-        return results
+        return dataset.get_original_order(results)

     def loglikelihood_rolling(
         self, requests: list[LoglikelihoodRollingRequest], override_bs=None
@@ -361,7 +372,7 @@ def loglikelihood_rolling(
                     )
                 )

-        return results
+        return dataset.get_original_order(results)

     def loglikelihood_single_token(
         self,
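All three return statements now pass through dataset.get_original_order(...): the task datasets reorder requests (typically sorting by length so similarly sized items share a batch), so results must be mapped back to the caller's order before returning. A sketch of the assumed mechanics, using a hypothetical original_order attribute on the dataset:

# Assumed mechanics: the dataset records the permutation applied when it
# sorted the requests, then inverts it so results line up with the inputs.
def get_original_order(self, results: list) -> list:
    restored = [None] * len(results)
    for sorted_pos, original_pos in enumerate(self.original_order):
        restored[original_pos] = results[sorted_pos]
    return restored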