Skip to content

Commit 32158b7

Browse files
committed
Add GSM8K prompt and dataset support for math tasks
Introduces GSM8K prompt and dataset configuration files for grade school math problem evaluation. Updates evaluator.py to support GSM8K answer extraction and adjusts evaluation logic for numeric answers. Modifies config.yaml for new optimal parameters and documents GSM8K support in the README.
1 parent b838099 commit 32158b7

File tree

5 files changed

+78
-12
lines changed

5 files changed

+78
-12
lines changed

examples/llm_prompt_optimization/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ This optimizer works with any HuggingFace dataset. Included examples:
9595

9696
- **IMDB Sentiment**: `initial_prompt.txt` + `initial_prompt_dataset.yaml` (binary classification)
9797
- **Emotion**: `emotion_prompt.txt` + `emotion_prompt_dataset.yaml` (6-class, benchmark against DSPy)
98+
- **GSM8K**: `gsm8k_prompt.txt` + `gsm8k_prompt_dataset.yaml` (grade school math, DSPy achieves 97.1%)
9899

99100
### Creating New Tasks
100101

examples/llm_prompt_optimization/config.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ llm:
1616
- name: "gemini-2.5-flash-lite" # Using Gemini 2.5 Flash Lite
1717
weight: 1.0
1818

19-
temperature: 0.4 # Optimal from experiments
19+
temperature: 0.8 # Optimal from experiments
2020
max_tokens: 16000 # Optimal context
2121
timeout: 150
2222
retries: 3
@@ -53,17 +53,17 @@ database:
5353

5454
# Selection parameters - Optimal ratios from testing
5555
elite_selection_ratio: 0.1 # 10% elite selection
56-
exploration_ratio: 0.3 # 30% exploration
57-
exploitation_ratio: 0.6 # 60% exploitation
56+
exploration_ratio: 0.5 # 50% exploration
57+
exploitation_ratio: 0.4 # 40% exploitation
5858

5959
# Migration parameters - Optimal settings
6060
migration_interval: 10
6161
migration_rate: 0.1
6262

6363
# Evaluator Configuration
6464
evaluator:
65-
timeout: 200
65+
timeout: 600
6666
max_retries: 3
6767
parallel_evaluations: 4
6868
cascade_evaluation: true # Two-stage cascading evaluation
69-
cascade_thresholds: [0.9] # Stage 1 must achieve 90% accuracy to proceed to stage 2
69+
cascade_thresholds: [0.4] # Stage 1 must achieve 40% accuracy to proceed to stage 2

examples/llm_prompt_optimization/evaluator.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,24 @@ def load_prompt_config(prompt_path):
6969
def load_hf_dataset(config):
7070
"""Load HuggingFace dataset based on configuration."""
7171
dataset_name = config['dataset_name']
72+
dataset_config = config.get('dataset_config', None)
7273
split = config.get('split', 'test')
7374

7475
print(f"Loading dataset: {dataset_name}")
7576

7677
try:
7778
# Try to load the specified split
78-
dataset = load_dataset(dataset_name, split=split)
79+
if dataset_config:
80+
dataset = load_dataset(dataset_name, dataset_config, split=split)
81+
else:
82+
dataset = load_dataset(dataset_name, split=split)
7983
except:
8084
# Fallback to train split if test is not available
8185
print(f"Split '{split}' not found, falling back to 'train'")
82-
dataset = load_dataset(dataset_name, split='train')
86+
if dataset_config:
87+
dataset = load_dataset(dataset_name, dataset_config, split='train')
88+
else:
89+
dataset = load_dataset(dataset_name, split='train')
8390

8491
print(f"Dataset loaded with {len(dataset)} examples")
8592
return dataset
@@ -89,8 +96,10 @@ def evaluate_prompt(prompt, dataset, config, num_samples):
8996
input_field = config['input_field']
9097
target_field = config['target_field']
9198

92-
# Check if this is emotion classification (0-5) or sentiment (0-1)
93-
is_emotion = 'emotion' in config.get('dataset_name', '').lower()
99+
# Check dataset type
100+
dataset_name = config.get('dataset_name', '').lower()
101+
is_emotion = 'emotion' in dataset_name
102+
is_gsm8k = 'gsm8k' in dataset_name
94103

95104
# Sample from dataset
96105
samples = dataset.select(range(min(num_samples, len(dataset))))
@@ -110,11 +119,14 @@ def evaluate_prompt(prompt, dataset, config, num_samples):
110119
# Call the LLM with retry logic
111120
for attempt in range(MAX_RETRIES):
112121
try:
122+
# Adjust max_tokens based on task
123+
max_tokens = 500 if is_gsm8k else 20
124+
113125
response = test_model.chat.completions.create(
114126
model=TASK_MODEL_NAME,
115127
messages=messages,
116-
temperature=0.1, # Low temperature for consistent classification
117-
max_tokens=20 # Allow slightly more tokens for emotion labels
128+
temperature=0.1, # Low temperature for consistent results
129+
max_tokens=max_tokens
118130
)
119131
break
120132
except Exception as e:
@@ -150,7 +162,41 @@ def evaluate_prompt(prompt, dataset, config, num_samples):
150162

151163
# Extract prediction from output
152164
try:
153-
if is_emotion:
165+
if is_gsm8k:
166+
# For GSM8K, extract the numeric answer after ####
167+
# First, extract the expected answer from the ground truth
168+
expected_answer = expected.split('####')[-1].strip()
169+
try:
170+
expected_number = float(expected_answer.replace(',', ''))
171+
except:
172+
print(f"Warning: Could not parse expected answer: {expected_answer}")
173+
total += 1
174+
continue
175+
176+
# Extract prediction from model output
177+
prediction = None
178+
if '####' in output_text:
179+
predicted_answer = output_text.split('####')[-1].strip()
180+
# Extract just the number, removing any extra text like $ signs
181+
import re
182+
numbers = re.findall(r'-?\$?[\d,]+\.?\d*', predicted_answer)
183+
if numbers:
184+
try:
185+
# Remove $ and , from the number
186+
number_str = numbers[0].replace('$', '').replace(',', '')
187+
prediction = float(number_str)
188+
except:
189+
pass
190+
191+
# If we found a prediction, check if it matches
192+
if prediction is not None:
193+
# Check if answers match (with small tolerance for floats)
194+
if abs(prediction - expected_number) < 0.001:
195+
correct += 1
196+
197+
total += 1
198+
199+
elif is_emotion:
154200
# For emotion classification (0-5)
155201
numbers = re.findall(r'\b[0-5]\b', output_text)
156202
if numbers:
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Solve the following grade school math problem step by step.
2+
3+
Problem: {input_text}
4+
5+
Show your work and reasoning for each step. After solving, provide your final numeric answer after "####".
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# HuggingFace dataset configuration for GSM8K (Grade School Math)
2+
# DSPy achieved 97.1% accuracy with GPT-4 on this benchmark
3+
dataset_name: "openai/gsm8k"
4+
dataset_config: "main" # GSM8K requires config name
5+
input_field: "question"
6+
target_field: "answer" # Contains step-by-step solution ending with #### followed by the numeric answer
7+
split: "test"
8+
9+
# Evaluation samples
10+
max_samples: 200 # Start with subset, full test set has 1,319 problems
11+
12+
# Note: The answer field contains the full solution with the format:
13+
# "Step 1 explanation... Step 2... #### numeric_answer"
14+
# The evaluator will need to extract the number after ####

0 commit comments

Comments
 (0)