Skip to content

Commit e309696

Browse files
committed
update script
1 parent d35672f commit e309696

File tree

1 file changed

+45
-14
lines changed

1 file changed

+45
-14
lines changed

agent/notebooks/convert_halted_questions.py renamed to agent/notebooks/convert_halted_unpublished_questions.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
1+
"""
2+
The script converts halted unpublished questions into a format suitable for gm-eval to read.
3+
4+
The input CSV expected following columns:
5+
- Question
6+
- Answer options
7+
- Correct answer
8+
- Very Wrong answer
9+
"""
10+
111
import pandas as pd
212
import re
313
import random
414
import os
15+
import argparse
516

617
def clean_option_text(option_text):
718
"""Removes leading option markers like 'A. ', 'B. ' etc."""
@@ -108,9 +119,29 @@ def determine_correctness(options, correct_index, very_wrong_answer):
108119

109120

110121
def main():
111-
# Define input and output paths
112-
input_csv_path = 'halted_questions.csv'
113-
output_dir = './'
122+
# Set up command line argument parsing
123+
parser = argparse.ArgumentParser(description='Convert halted questions CSV to questions and options CSV files.')
124+
parser.add_argument('-i', '--input',
125+
default='halted_questions.csv',
126+
help='Path to input CSV file (default: halted_questions.csv)')
127+
parser.add_argument('-o', '--output-dir',
128+
default='./',
129+
help='Output directory for generated CSV files (default: current directory)')
130+
parser.add_argument('-s', '--start-index',
131+
type=int,
132+
default=0,
133+
help='Starting index for question IDs (default: 0)')
134+
135+
args = parser.parse_args()
136+
137+
# Use command line arguments
138+
input_csv_path = args.input
139+
output_dir = args.output_dir
140+
start_index = args.start_index
141+
142+
# Ensure output directory exists
143+
os.makedirs(output_dir, exist_ok=True)
144+
114145
output_questions_path = os.path.join(output_dir, 'questions.csv')
115146
output_options_path = os.path.join(output_dir, 'question_options.csv')
116147

@@ -126,17 +157,17 @@ def main():
126157
# The CSV has multiple initial rows that are not the true header.
127158
# We'll look for a row that contains expected column names.
128159
df_full = pd.read_csv(input_csv_path, dtype=str)
129-
df_full.columns = df_full.columns.map(lambda x: x.strip().replace("\n", " "))
160+
df_full.columns = df_full.columns.map(lambda x: x.strip().replace("\n", " ").lower())
130161

131-
expected_cols = ['Question', 'Correct Answer', 'Answer options', 'Very Wrong Answer']
162+
expected_cols = ['question', 'correct answer', 'answer options', 'very wrong answer']
132163

133164
# Drop rows where essential information for questions or options is missing
134165
# or where ID combo is a placeholder like '#REF!'
135166
# Note: Very Wrong Answer can be missing, so we don't include it in the required columns
136-
required_cols = ['Question', 'Correct Answer', 'Answer options']
167+
required_cols = ['question', 'correct answer', 'answer options']
137168
df_input = df_full.dropna(subset=required_cols)[expected_cols]
138-
if not df_input[df_input.duplicated(subset=["Question"])].empty:
139-
print(df_input[df_input.duplicated(subset=["Question"])])
169+
if not df_input[df_input.duplicated(subset=["question"])].empty:
170+
print(df_input[df_input.duplicated(subset=["question"])])
140171
raise ValueError("Duplicated question")
141172

142173
total_questions = len(df_input)
@@ -148,18 +179,18 @@ def main():
148179
processed_question_ids = set()
149180

150181
for index, row in df_input.iterrows():
151-
question_id_raw = str(index)
182+
question_id_raw = str(index + start_index)
152183

153184
# Safely extract question text, handling cases where it might be a Series
154-
raw_question_val = row['Question']
185+
raw_question_val = row['question']
155186
if isinstance(raw_question_val, pd.Series):
156187
actual_question_string = raw_question_val.iloc[0]
157188
else:
158189
actual_question_string = raw_question_val
159-
question_text = str(actual_question_string).strip().title()
190+
question_text = str(actual_question_string).strip()
160191

161-
correct_answer_text_raw = str(row.get('Correct Answer', '')).strip().title()
162-
answer_options_str = str(row.get('Answer options', '')).strip()
192+
correct_answer_text_raw = str(row.get('correct answer', '')).strip()
193+
answer_options_str = str(row.get('answer options', '')).strip()
163194

164195
if not question_id_raw or not question_text:
165196
print(f"Skipping row {index} due to missing ID combo or Question text.")
@@ -270,7 +301,7 @@ def main():
270301
new_option_letter_map = {i: chr(65 + i) for i in range(len(selected_options))}
271302

272303
# Get the very wrong answer for this question, if available
273-
raw_vwa_val = row.get('Very Wrong Answer')
304+
raw_vwa_val = row.get('very wrong answer')
274305
very_wrong_answer = str(raw_vwa_val).strip() if pd.notna(raw_vwa_val) else ''
275306

276307
# Create a map of options and their correctness values

0 commit comments

Comments
 (0)