Skip to content

Commit 08e374b

Browse files
committed
added quality
1 parent 11db56e commit 08e374b

File tree

1 file changed

+70
-1
lines changed

1 file changed

+70
-1
lines changed

task.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
import numpy as np
21
from abc import ABC, abstractmethod
32
from string import ascii_uppercase
3+
4+
import numpy as np
45
from datasets import load_dataset
6+
57
from metric import AutoMetric
68

79

@@ -409,13 +411,80 @@ def _process_logits(self, logits, split):
409411
return preds
410412

411413

414+
class ScrollsQuality(LogitEvaluationTask):
    """
    Multiple-choice QA task built on the QuALITY subset of `tau/scrolls`.

    Uses the preprocessed mirror at https://huggingface.co/datasets/rbiswasfc/quality.
    Since the test split ships without ground-truth labels, evaluation is
    performed on the validation split instead.
    """

    DEFAULT_PROMPT_TEMPLATE = """You will be given a context, a question related to that context, and four possible answer choices. Carefully read the context, question, and answer choices, then select the best answer.
IMPORTANT: Provide only the letter corresponding to your chosen answer. Do not write out the full answer or give any explanation.

Context:
{context}

Question:
{question}

Answer choices:
{choices}

Answer:
"""

    def __init__(
        self, prompt_template=DEFAULT_PROMPT_TEMPLATE, max_tokens=128, **kwargs
    ):
        super().__init__(
            prompt_template, max_tokens, hf_args=["rbiswasfc/quality"], **kwargs
        )

        # Labels are letters, so plain accuracy is the only metric.
        self.metrics = {
            "Accuracy": AutoMetric.from_name("accuracy"),
        }
        # Test split has no ground truths; score on validation instead.
        self.test_split = "validation"

        # _process_logits needs to know how many options each example has.
        self.mandatory_cols.append("num_choices")

    def prepare_row(self, row: dict):
        """Turn one raw dataset row into a prompt plus its letter label."""
        passage = row["context"]
        query = row["question"]
        options = row["choices"]
        option_count = len(options)
        gold_letter = ascii_uppercase[row["label"]]

        # Render options as "A. ...", "B. ...", one per line.
        rendered_options = "\n".join(
            f"{letter}. {option}" for letter, option in zip(ascii_uppercase, options)
        )

        return {
            "context": passage,
            "question": query,
            "prompt": self.prompt_template.format(
                context=passage, question=query, choices=rendered_options
            ),
            "labels": gold_letter,
            "num_choices": option_count,
        }

    def _process_logits(self, logits, split):
        """Pick, per example, the answer letter with the highest logit."""
        predictions = []
        choice_counts = self.get_split(split)["num_choices"]
        for row_logits, count in zip(logits, choice_counts):
            # Only the first `count` letters are valid options for this row.
            scores = [row_logits[letter] for letter in ascii_uppercase[:count]]
            predictions.append(ascii_uppercase[int(np.argmax(scores))])

        return predictions
412480
# Registry of evaluation tasks: maps the lowercase task name (as passed on the
# CLI / in config) to the task class implementing it. New tasks must be
# registered here to be discoverable.
TASK_MAPPING = {
    "squality": Squality,
    "triviaqa": TriviaQA,
    "dolomites": Dolomites,
    "qmsum": QMSum,
    "musique": Musique,
    "truthfulqa": TruthfulQA,
    "scrollsquality": ScrollsQuality,
}
420489

421490

0 commit comments

Comments
 (0)