
Commit 491a3a7

Small fixes to InferenceEndpointModel (#112)
- Added implementation for missing properties in InferenceEndpointModel
- Added tokenized context in greedy generate + handled stop_sequence of tuple/list/str
- Sanitized endpoint model name (to extend later)
- Redid disable_tqdm & fixed call to batch generate with logits
- Removed debug flag & swapped debug to true
- Added get original order for InferenceEndpoint calls
1 parent b1aa626 commit 491a3a7

File tree

4 files changed: +25 -11 lines

- src/lighteval/main_accelerate.py
- src/lighteval/models/endpoint_model.py
- src/lighteval/models/model_config.py
- src/lighteval/tasks/requests.py

src/lighteval/main_accelerate.py

Lines changed: 2 additions & 1 deletion
@@ -142,6 +142,7 @@ def main(args):
 
     print(make_results_table(final_dict))
 
-    model.cleanup()
+    if not args.reuse_existing:
+        model.cleanup()
 
     return final_dict
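This guard keeps an endpoint the user asked to reuse from being deleted at the end of the run; it pairs with `should_reuse_existing` in `InferenceEndpointModelConfig` further down. A minimal sketch of how such a flag is typically wired up, assuming an argparse option named `--reuse-existing` (the CLI definition itself is not part of this diff, so the names below are assumptions):

# Sketch only: the flag name and parser setup are assumptions, not committed code.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--reuse-existing",
    dest="reuse_existing",
    action="store_true",
    help="Reuse an already-deployed inference endpoint instead of creating (and later deleting) one",
)
args = parser.parse_args(["--reuse-existing"])

# Mirrors the change in main_accelerate.py: only clean up endpoints we created.
if not args.reuse_existing:
    print("model.cleanup() would run here and delete the endpoint")
else:
    print("endpoint is left running so it can be reused by a later run")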

src/lighteval/models/endpoint_model.py

Lines changed: 18 additions & 7 deletions
@@ -109,14 +109,23 @@ def __init__(
         self.async_client = AsyncInferenceClient(model=config.model, token=env_config.token)
         self.client = InferenceClient(model=config.model, token=env_config.token)
 
-        self.use_async = False  # for debug - async use is faster
+        self.use_async = True  # set to False for debug - async use is faster
 
         self._tokenizer = AutoTokenizer.from_pretrained(self.name)
+        self._add_special_tokens = config.add_special_tokens if config.add_special_tokens is not None else False
 
     @property
     def tokenizer(self):
         return self._tokenizer
 
+    @property
+    def add_special_tokens(self):
+        return self._add_special_tokens
+
+    @property
+    def disable_tqdm(self) -> bool:
+        False  # no accelerator = this is the main process
+
     def cleanup(self):
         if self.endpoint is not None:
             self.endpoint.delete()
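The new `add_special_tokens` property is presumably read when prompts are encoded (the `tok_encode` call added to `greedy_until` below); that helper is not shown in this diff, so the sketch here is an assumption about the pattern rather than the committed implementation:

# Sketch, assuming a Hugging Face tokenizer; `tok_encode` here is hypothetical.
from typing import Optional
from transformers import AutoTokenizer

class EndpointTokenizerSketch:
    def __init__(self, model_name: str, add_special_tokens: Optional[bool]):
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Same fallback as the diff: default to False when the config leaves it unset.
        self._add_special_tokens = add_special_tokens if add_special_tokens is not None else False

    @property
    def add_special_tokens(self) -> bool:
        return self._add_special_tokens

    def tok_encode(self, text: str) -> list:
        # Pass the flag straight through to the tokenizer when encoding a prompt.
        return self._tokenizer.encode(text, add_special_tokens=self.add_special_tokens)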
@@ -250,7 +259,8 @@ def greedy_until(
         override_bs: Optional[int] = None,
     ) -> List[GenerateReturn]:
         for request in requests:
-            request.stop_sequence = request.stop_sequence + [self.tokenizer.eos_token]
+            request.tokenized_context = self.tok_encode(request.context)
+            request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]
 
         dataset = GenerativeTaskDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
         batch_size = override_bs if override_bs is not None else BATCH_SIZE
@@ -268,10 +278,11 @@ def greedy_until(
         for batch in tqdm(
             dataloader, desc="Greedy generation", position=1, leave=False, disable=self.disable_tqdm
         ):
+            # the `returns_logits` flag is only used to filter the results, we always request the full details.
            if self.use_async:
-                responses = asyncio.run(self.__async_process_batch_generate(batch, returns_logits))
+                responses = asyncio.run(self.__async_process_batch_generate(batch))
            else:
-                responses = self.__process_batch_generate(batch, returns_logits)
+                responses = self.__process_batch_generate(batch)
            for response in responses:
                results.append(
                    GenerateReturn(
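`as_list` (a lighteval utility) is what makes the widened `stop_sequence` types work: a bare string or tuple is turned into a list before the EOS token is appended, so a string is never iterated character by character. A small illustration with a local stand-in for the utility, in case its exact behaviour or import path differs:

from typing import Union

def as_list(item: Union[str, tuple, list]) -> list:
    # Stand-in for lighteval's helper: wrap scalars, convert tuples, keep lists.
    if isinstance(item, list):
        return item
    if isinstance(item, tuple):
        return list(item)
    return [item]

eos_token = "</s>"
print(as_list("###") + [eos_token])            # ['###', '</s>']
print(as_list(("###", "\n\n")) + [eos_token])  # ['###', '\n\n', '</s>']
print(as_list(["###", "\n\n"]) + [eos_token])  # ['###', '\n\n', '</s>']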
@@ -282,7 +293,7 @@ def greedy_until(
                     )
                 )
 
-        return results
+        return dataset.get_original_order(results)
 
     def loglikelihood(
         self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
@@ -321,7 +332,7 @@ def loglikelihood(
                     )
                 )
 
-        return results
+        return dataset.get_original_order(results)
 
     def loglikelihood_rolling(
         self, requests: list[LoglikelihoodRollingRequest], override_bs=None
@@ -361,7 +372,7 @@ def loglikelihood_rolling(
                     )
                 )
 
-        return results
+        return dataset.get_original_order(results)
 
     def loglikelihood_single_token(
         self,
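`get_original_order` is needed because the task datasets reorder requests (typically grouping them by length) so that batches are more homogeneous; results therefore come back in sorted order and must be mapped back to the order the requests were submitted in. A toy stand-in for that round trip (the real `GenerativeTaskDataset` implementation is not shown in this diff):

class SortedDatasetSketch:
    """Toy stand-in: sorts items for batching, then restores the caller's order."""

    def __init__(self, items: list):
        # Process longest items first, remembering where each one came from.
        self.sorted_indices = sorted(range(len(items)), key=lambda i: -len(items[i]))
        self.sorted_items = [items[i] for i in self.sorted_indices]

    def get_original_order(self, results: list) -> list:
        restored = [None] * len(results)
        for position, original_index in enumerate(self.sorted_indices):
            restored[original_index] = results[position]
        return restored

requests = ["a short prompt", "a much, much longer prompt", "a mid-sized prompt"]
dataset = SortedDatasetSketch(requests)
results = [f"generation for: {item}" for item in dataset.sorted_items]  # produced in sorted order
print(dataset.get_original_order(results))  # back in the original submission order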

src/lighteval/models/model_config.py

Lines changed: 3 additions & 1 deletion
@@ -221,6 +221,7 @@ class TGIModelConfig:
 @dataclass
 class InferenceModelConfig:
     model: str
+    add_special_tokens: bool = True
 
 
 @dataclass
@@ -235,6 +236,7 @@ class InferenceEndpointModelConfig:
     framework: str = "pytorch"
     endpoint_type: str = "protected"
     should_reuse_existing: bool = False
+    add_special_tokens: bool = True
 
 
 def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]) -> BaseModelConfig:  # noqa: C901
@@ -270,7 +272,7 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]
     # Endpoint
     if args.endpoint_model_name:
         if args.reuse_existing or args.vendor is not None:
-            model = args.endpoint_model_name.split("/")[1].lower()
+            model = args.endpoint_model_name.split("/")[1].replace(".", "-").lower()
             return InferenceEndpointModelConfig(
                 name=f"{model}-lighteval",
                 repository=args.endpoint_model_name,
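The added `.replace(".", "-")` matters because Inference Endpoint names only accept a restricted character set (roughly lowercase letters, digits, and hyphens), and model revisions such as `v0.1` would otherwise leave a dot in the generated name. A quick worked example with a hypothetical repository id:

# The repository id below is illustrative only.
endpoint_model_name = "mistralai/Mistral-7B-v0.1"

model = endpoint_model_name.split("/")[1].replace(".", "-").lower()
print(model)                  # mistral-7b-v0-1
print(f"{model}-lighteval")   # mistral-7b-v0-1-lighteval  <- endpoint name passed to the config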

src/lighteval/tasks/requests.py

Lines changed: 2 additions & 2 deletions
@@ -114,7 +114,7 @@ class GreedyUntilRequest(Request):
         request_type (RequestType): The type of the request, set to RequestType.GREEDY_UNTIL.
     """
 
-    stop_sequence: str
+    stop_sequence: Union[str, tuple[str], list[str]]
     generation_size: int
     request_type = RequestType.GREEDY_UNTIL
     tokenized_context: list[int] = None
@@ -132,7 +132,7 @@ class GreedyUntilWithLogitsRequest(Request):
         request_type (RequestType): The type of the request (GREEDY_UNTIL_WITH_LOGITS).
     """
 
-    stop_sequence: str
+    stop_sequence: Union[str, tuple[str], list[str]]
     generation_size: int
     request_type = RequestType.GREEDY_UNTIL_WITH_LOGITS
     tokenized_context: list[int] = None
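With the widened annotation, a task can declare a single stop string or several; `greedy_until` normalizes either form with `as_list` before appending the EOS token (see endpoint_model.py above). A minimal sketch mirroring just the fields visible in this diff (the real `Request` base class has additional fields that are omitted here):

from dataclasses import dataclass
from typing import Optional, Union

@dataclass
class GreedyUntilRequestSketch:
    # Only the fields visible in this diff; omitted base-class fields are unknown here.
    stop_sequence: Union[str, tuple, list]
    generation_size: int
    tokenized_context: Optional[list] = None

single = GreedyUntilRequestSketch(stop_sequence="\n\n", generation_size=256)
several = GreedyUntilRequestSketch(stop_sequence=("\n\n", "###"), generation_size=256)
print(single.stop_sequence, several.stop_sequence)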
