Skip to content

Commit ab8bfd9

Browse files
committed
Add PubMedQA dataset
1 parent bccc277 commit ab8bfd9

File tree

2 files changed

+261
-0
lines changed

2 files changed

+261
-0
lines changed
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
"""Load and process the PubMedQA dataset.
2+
3+
Dataset: HuggingFace `qiaojin/PubMedQA` dataset.
4+
Each example is normalized to the following fields:
5+
{
6+
"question": "<formatted question with context>", # complete prompt with abstract
7+
"answer": "<A|B|C>", # A=yes, B=no, C=maybe
8+
"info": { ...original example fields... } # full source row for debugging
9+
}
10+
"""
11+
12+
import json
13+
import os
14+
from typing import Any
15+
16+
from datasets import load_dataset
17+
18+
19+
class PubMedQADataset:
    """Load and normalize the PubMedQA dataset into question/answer/info rows."""

    def __init__(
        self,
        num_train_examples: int = -1,
        num_test_examples: int = -1,
    ):
        """Initialize the PubMedQA dataset processor.

        Args:
            num_train_examples: Number of training examples to use (-1 for all)
            num_test_examples: Number of test examples to use (-1 for all)
        """
        self.num_train_examples = num_train_examples
        self.num_test_examples = num_test_examples
        # Fixed seed so train/test shuffles are reproducible across runs.
        self.rng_seed = 12345
        self.dataset_path = "qiaojin/PubMedQA"

        # Load and process datasets on initialization
        self.train_ds, self.test_ds = self._load_and_process_datasets()

    def _load_and_process_datasets(self) -> tuple:
        """Load, filter, truncate, format, and shuffle the PubMedQA splits.

        Returns:
            Tuple of (train_dataset, test_dataset), each row having
            ``question``, ``answer`` and ``info`` fields.
        """
        # pqa_artificial is the training set, pqa_labeled is the test set
        # (both configs only expose a "train" split upstream).
        train_raw = load_dataset(
            self.dataset_path, name="pqa_artificial", split="train"
        )
        test_raw = load_dataset(self.dataset_path, name="pqa_labeled", split="train")

        # Filter test set to only include human-annotated samples
        test_raw = self._filter_test_set(test_raw)

        # Limit number of examples if specified
        if self.num_train_examples != -1:
            train_raw = train_raw.select(
                range(min(self.num_train_examples, len(train_raw)))
            )
        if self.num_test_examples != -1:
            test_raw = test_raw.select(
                range(min(self.num_test_examples, len(test_raw)))
            )

        # Format datasets
        train_formatted = self._format_dataset(train_raw, "train")
        test_formatted = self._format_dataset(test_raw, "test")

        # Shuffle datasets
        train_formatted = train_formatted.shuffle(seed=self.rng_seed)
        test_formatted = test_formatted.shuffle(seed=self.rng_seed)

        return train_formatted, test_formatted

    def _filter_test_set(self, dataset: Any) -> Any:
        """Filter test set to only include human-annotated samples (500 from 1000).

        Falls back to the full test set (with a warning) when the
        ground-truth ID file is missing.
        """
        # Load the predefined test IDs
        here = os.path.dirname(__file__)
        file_path = os.path.join(here, "data", "test_ground_truth.json")

        try:
            with open(file_path) as f:
                # Materialize as a set for O(1) membership tests during
                # filtering (membership on a raw JSON list would be O(n)
                # per row; a set also matches the eval script's handling
                # of the same file).
                test_ids = set(json.load(f))

            # Filter to only the 500 human-annotated samples
            return dataset.filter(lambda sample: str(sample["pubid"]) in test_ids)
        except FileNotFoundError:
            # If the file doesn't exist, return the full test set
            print(f"Warning: {file_path} not found. Using full test set.")
            return dataset

    def _format_dataset(self, dataset: Any, split: str) -> Any:
        """Format dataset with question, answer, and info fields.

        Args:
            dataset: Raw rows with ``question``, ``context`` (dict holding
                ``labels``/``contexts`` lists) and ``final_decision``.
            split: Split name ("train"/"test"); currently unused, kept for
                call-site symmetry.
        """
        choices_map = {"yes": "A", "no": "B", "maybe": "C"}
        prompt_template = "Answer A for yes, B for no or C for maybe.\n\nContext: {context}\n\nQuestion: {question}\nAnswer:"

        def format_row(row: dict) -> dict:
            row = dict(row)

            # Extract question; `or ""` guards against explicit None values
            question_text = row.get("question", "") or ""

            # Extract and format context
            context_dict = row.get("context", {}) or {}
            labels = context_dict.get("labels", []) or []
            contexts = context_dict.get("contexts", []) or []

            # Prefix each abstract section with its label, e.g. "METHODS. ..."
            formatted_contexts = []
            for label, context in zip(labels, contexts):
                formatted_contexts.append(f"{label}. {context}")
            context_text = "\n".join(formatted_contexts)

            # Build complete prompt
            complete_prompt = prompt_template.format(
                context=context_text, question=question_text
            )

            # Map final decision to letter (A/B/C); unknown values yield ""
            final_decision = (row.get("final_decision", "") or "").lower()
            answer = choices_map.get(final_decision, "")

            # Keep full original example under 'info' for debugging
            info = dict(row)

            return {
                "question": complete_prompt,
                "answer": answer,
                "info": info,
            }

        return dataset.map(format_row, load_from_cache_file=False)
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
"""PubMedQA Evaluation.
2+
3+
Dataset: HuggingFace `qiaojin/PubMedQA` (pqa_labeled and pqa_artificial splits).
4+
5+
- Parser: Extracts \\boxed{A|B|C} from completions (A=yes, B=no, C=maybe)
6+
- Reward Functions:
7+
- Exact match classification reward
8+
"""
9+
10+
import json
11+
import os
12+
13+
import verifiers as vf
14+
from datasets import load_dataset
15+
from dotenv import load_dotenv
16+
from openai import OpenAI
17+
18+
19+
def map_row_to_mcq_prompt(row):
    """Map a PubMedQA row to an MCQ-style prompt with A/B/C answers.

    Args:
        row: Source example with ``question``, ``context`` (dict holding
            ``labels``/``contexts`` lists) and ``final_decision`` fields.

    Returns:
        Dict with ``question`` (complete prompt), ``answer`` (letter
        A/B/C) and ``task`` fields.
    """
    # Guard against explicit None values as well as missing keys — the
    # dataset-loader class in this commit applies the same normalization,
    # and a bare .get() would crash on None.lower() / None.get().
    question_text = row.get("question", "") or ""
    context_dict = row.get("context", {}) or {}
    labels = context_dict.get("labels", []) or []
    contexts = context_dict.get("contexts", []) or []
    final_decision = (row.get("final_decision", "") or "").lower()

    choices_map = {"yes": "A", "no": "B", "maybe": "C"}
    correct_answer_letter = choices_map.get(final_decision, "C")  # default to maybe

    # Prefix each abstract section with its label, e.g. "METHODS. ..."
    formatted_contexts = [
        f"{label}. {context}" for label, context in zip(labels, contexts)
    ]
    context_text = "\n".join(formatted_contexts)

    complete_prompt = (
        f"Answer A for yes, B for no or C for maybe.\n\n"
        f"Context: {context_text}\n\n"
        f"Question: {question_text}\nAnswer: "
    )

    return {
        "question": complete_prompt,
        "answer": correct_answer_letter,
        "task": "pubmedqa",
    }
46+
47+
48+
def classification_reward_func(prompt, completion, answer, state, **kwargs) -> float:
    """Exact-match reward: 1.0 when the parsed letter equals the ground truth.

    ``completion`` may be a chat-style list of message dicts or a plain
    value; the parser is expected under ``kwargs["parser"]``.
    """
    # Pull the text out of a chat-style completion, otherwise stringify.
    if isinstance(completion, list) and completion:
        text = completion[0].get("content", "")
    else:
        text = str(completion)

    # Without the rubric's parser there is nothing to grade against.
    parser = kwargs.get("parser")
    if parser is None:
        return 0.0

    extracted = parser.parse(text)
    # Normalize: trim whitespace and a trailing period; empty parse -> None.
    prediction = extracted.strip().rstrip(".") if extracted else None

    return 1.0 if prediction == answer else 0.0
65+
66+
67+
def main() -> None:
    """Run evaluation on PubMedQA.

    Loads the train/test splits, maps rows to MCQ prompts, builds a
    single-turn environment with a boxed-answer parser and an exact-match
    rubric, then evaluates a Groq-hosted model.
    """
    load_dotenv()

    # Load datasets: pqa_artificial for training, pqa_labeled for testing.
    DATASET_PATH = "qiaojin/PubMedQA"
    dataset_train = load_dataset(DATASET_PATH, name="pqa_artificial", split="train")
    dataset_test = load_dataset(DATASET_PATH, name="pqa_labeled", split="train")

    # Filter test set to the human-annotated 500 examples. Degrade
    # gracefully to the full labeled set when the ground-truth ID file is
    # absent, mirroring PubMedQADataset._filter_test_set instead of crashing.
    here = os.path.dirname(__file__)
    file_path = os.path.join(here, "data", "test_ground_truth.json")
    try:
        with open(file_path) as f:
            test_ids = set(json.load(f))  # use set for O(1) lookup

        dataset_test = dataset_test.filter(
            lambda sample: str(sample["pubid"]) in test_ids, load_from_cache_file=False
        )
    except FileNotFoundError:
        print(f"Warning: {file_path} not found. Using full test set.")

    # Map to standard format
    mapped_train = dataset_train.map(
        map_row_to_mcq_prompt, load_from_cache_file=False, keep_in_memory=True
    )
    mapped_test = dataset_test.map(
        map_row_to_mcq_prompt, load_from_cache_file=False, keep_in_memory=True
    )

    # Use boxed-only system prompt (no chain-of-thought)
    system_prompt = vf.utils.data_utils.BOXED_SYSTEM_PROMPT
    parser = vf.parsers.parser.Parser(extract_fn=vf.extract_boxed_answer)

    # Build rubric: a single exact-match reward with full weight.
    rubric = vf.Rubric(
        funcs=[classification_reward_func],
        weights=[1.0],
        parser=parser,
    )

    # Create environment
    env = vf.SingleTurnEnv(
        dataset=mapped_train,
        eval_dataset=mapped_test,
        system_prompt=system_prompt,
        parser=parser,
        rubric=rubric,
    )

    # Initialize client (Groq via OpenAI-compatible API)
    client = OpenAI(
        api_key=os.getenv("GROQ_API_KEY"),
        base_url="https://api.groq.com/openai/v1",
    )

    # Run evaluation (small smoke-test size: 2 examples x 5 rollouts)
    results = env.evaluate(
        client=client,
        model="llama-3.3-70b-versatile",
        num_examples=2,
        rollouts_per_example=5,
    )
    print(results)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)