|
23 | 23 | import asyncio |
24 | 24 | from typing import Coroutine, List, Optional, Union |
25 | 25 |
|
| 26 | +import torch |
26 | 27 | from huggingface_hub import ( |
27 | 28 | AsyncInferenceClient, |
28 | 29 | InferenceClient, |
@@ -314,17 +315,22 @@ def loglikelihood( |
314 | 315 | ): |
315 | 316 | dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch) |
316 | 317 |
|
317 | | - for batch in tqdm(dataloader, desc="Loglikleihoods", position=1, leave=False, disable=self.disable_tqdm): |
| 318 | + for batch in tqdm(dataloader, desc="Loglikelihoods", position=1, leave=False, disable=self.disable_tqdm): |
318 | 319 | if self.use_async: |
319 | 320 | responses = asyncio.run(self.__async_process_batch_logprob(batch)) |
320 | 321 | else: |
321 | 322 | responses = self.__process_batch_logprob(batch) |
322 | | - for ix, response in enumerate(responses): |
323 | | - len_choice = len(batch[ix].tokenized_continuation) |
| 323 | + for cur_request, response in zip(batch, responses): |
| 324 | + cont_toks = torch.tensor(cur_request.tokenized_continuation) |
| 325 | + len_choice = len(cont_toks) |
| 326 | + |
324 | 327 | logits = [t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None] |
| 328 | + |
| 329 | + greedy_tokens = torch.tensor(logits).argmax(dim=-1) |
| 330 | + max_equal = (greedy_tokens == cont_toks).all().squeeze(0) |
325 | 331 | results.append( |
326 | 332 | LoglikelihoodReturn( |
327 | | - result=sum(logits), |
| 333 | + result=(sum(logits), bool(max_equal)), |
328 | 334 | input_tokens=[t.id for t in response.details.prefill[:-len_choice]], |
329 | 335 | generated_tokens=[t.id for t in response.details.prefill[-len_choice:]], |
330 | 336 | truncated_tokens_count=-1, |
@@ -355,13 +361,16 @@ def loglikelihood_rolling( |
355 | 361 | ): |
356 | 362 | dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch) |
357 | 363 |
|
358 | | - for batch in tqdm(dataloader, desc="Loglikleihoods", position=1, leave=False, disable=self.disable_tqdm): |
| 364 | + for batch in tqdm( |
| 365 | + dataloader, desc="Loglikelihoods, rolling", position=1, leave=False, disable=self.disable_tqdm |
| 366 | + ): |
359 | 367 | if self.use_async: |
360 | 368 | responses = asyncio.run(self.__async_process_batch_logprob(batch, rolling=True)) |
361 | 369 | else: |
362 | 370 | responses = self.__process_batch_logprob(batch, rolling=True) |
363 | 371 | for response in responses: |
364 | 372 | logits = [t.logprob for t in response.details.tokens[:-1]] |
| 373 | + |
365 | 374 | results.append( |
366 | 375 | LoglikelihoodReturn( |
367 | 376 | result=sum(logits), |
|
0 commit comments