-import numpy as np
 from abc import ABC, abstractmethod
 from string import ascii_uppercase
+
+import numpy as np
 from datasets import load_dataset
+
 from metric import AutoMetric


@@ -409,13 +411,80 @@ def _process_logits(self, logits, split):
         return preds


+class ScrollsQuality(LogitEvaluationTask):
+    """
+    Evaluation dataset derived from `tau/scrolls`, processed into a suitable
+    format here: https://huggingface.co/datasets/rbiswasfc/quality.
+    The test split has no ground-truth labels, so the validation split is
+    used in its place.
+    """
+
+    DEFAULT_PROMPT_TEMPLATE = """You will be given a context, a question related to that context, and four possible answer choices. Carefully read the context, question, and answer choices, then select the best answer.
+IMPORTANT: Provide only the letter corresponding to your chosen answer. Do not write out the full answer or give any explanation.
+
+Context:
+{context}
+
+Question:
+{question}
+
+Answer choices:
+{choices}
+
+Answer:
+"""
+
+    def __init__(
+        self, prompt_template=DEFAULT_PROMPT_TEMPLATE, max_tokens=128, **kwargs
+    ):
+        super().__init__(
+            prompt_template, max_tokens, hf_args=["rbiswasfc/quality"], **kwargs
+        )
+
+        self.metrics = {
+            "Accuracy": AutoMetric.from_name("accuracy"),
+        }
+        # Test split doesn't have ground truths; use the validation split.
+        self.test_split = "validation"
+
+        self.mandatory_cols.append("num_choices")
+
+    def prepare_row(self, row: dict):
+        context = row["context"]
+        question = row["question"]
+        choices = row["choices"]
+        num_choices = len(choices)
+        # Map the integer label to its answer letter (0 -> "A", 1 -> "B", ...).
+        answer = ascii_uppercase[row["label"]]
+
+        # Render the options as lettered lines: "A. ...", "B. ...", etc.
+        choices = "\n".join(
+            [f"{char}. {opt}" for char, opt in zip(ascii_uppercase, choices)]
+        )
+
+        return {
+            "context": context,
+            "question": question,
+            "prompt": self.prompt_template.format(
+                context=context, question=question, choices=choices
+            ),
+            "labels": answer,
+            "num_choices": num_choices,
+        }
+
+    def _process_logits(self, logits, split):
+        preds = []
+        # For each example, score only the letters that are valid choices
+        # and predict the highest-scoring one.
+        for l, nc in zip(logits, self.get_split(split)["num_choices"]):
+            pred = [l[ascii_uppercase[i]] for i in range(nc)]
+            preds.append(ascii_uppercase[np.argmax(pred)])
+
+        return preds
+
+
 TASK_MAPPING = {
     "squality": Squality,
     "triviaqa": TriviaQA,
     "dolomites": Dolomites,
     "qmsum": QMSum,
     "musique": Musique,
     "truthfulqa": TruthfulQA,
+    "scrollsquality": ScrollsQuality,
 }


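As a sanity check on the prompt format, here is a standalone sketch of the rendering `prepare_row` performs for one record. The field names (`context`, `question`, `choices`, `label`) match the diff, but the values are made up for illustration and are not taken from the dataset.

```python
from string import ascii_uppercase

# Hypothetical record shaped like the processed rbiswasfc/quality rows
# (keys taken from prepare_row above; values invented for illustration).
row = {
    "context": "The expedition reached base camp after nine days.",
    "question": "How long did the expedition take to reach base camp?",
    "choices": ["Nine days", "Two weeks", "One month", "Three days"],
    "label": 0,
}

# Same rendering as prepare_row: lettered options and a letter label.
choices = "\n".join(
    f"{char}. {opt}" for char, opt in zip(ascii_uppercase, row["choices"])
)
label = ascii_uppercase[row["label"]]

print(choices)
# A. Nine days
# B. Two weeks
# C. One month
# D. Three days
print(label)  # -> "A"
```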
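For reviewers, here is a minimal standalone sketch of the prediction step in `_process_logits`: scores are read only for the letters that are valid choices for the example, and the highest-scoring letter becomes the prediction. The dict-of-letter-logits shape and the toy values are assumptions inferred from the indexing `l[ascii_uppercase[i]]` in the diff, not the harness's actual output format.

```python
import numpy as np
from string import ascii_uppercase

# Toy per-letter scores for one example (hypothetical values; the real
# `logits` come from the evaluation harness).
letter_logits = {"A": -1.2, "B": 0.7, "C": -0.3, "D": 0.1}
num_choices = 4  # taken from the row's "num_choices" column

# Restrict the argmax to the first `num_choices` letters, so an example
# with, say, three options can never be predicted as "D".
scores = [letter_logits[ascii_uppercase[i]] for i in range(num_choices)]
prediction = ascii_uppercase[int(np.argmax(scores))]
print(prediction)  # -> "B"
```

Restricting the argmax to the valid letters is why `num_choices` is appended to `mandatory_cols` and persisted per row. Once merged, the task is reachable through `TASK_MAPPING["scrollsquality"]` like the other tasks.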