@@ -69,17 +69,24 @@ def load_prompt_config(prompt_path):
def load_hf_dataset(config):
    """Load a HuggingFace dataset based on configuration.

    Args:
        config: dict with required key 'dataset_name' and optional keys
            'dataset_config' (a subset/config name passed as the second
            positional argument to ``load_dataset``) and 'split'
            (defaults to 'test').

    Returns:
        The loaded dataset. If the requested split is unavailable, falls
        back to the 'train' split.
    """
    dataset_name = config['dataset_name']
    dataset_config = config.get('dataset_config', None)
    split = config.get('split', 'test')

    print(f"Loading dataset: {dataset_name}")

    # Build the positional args once instead of duplicating the
    # load_dataset call in four branches: (path,) or (path, config_name).
    load_args = (dataset_name, dataset_config) if dataset_config else (dataset_name,)

    try:
        # Try to load the specified split
        dataset = load_dataset(*load_args, split=split)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate. Fallback to train split if test is not available.
        print(f"Split '{split}' not found, falling back to 'train'")
        dataset = load_dataset(*load_args, split='train')

    print(f"Dataset loaded with {len(dataset)} examples")
    return dataset
@@ -89,8 +96,10 @@ def evaluate_prompt(prompt, dataset, config, num_samples):
8996 input_field = config ['input_field' ]
9097 target_field = config ['target_field' ]
9198
92- # Check if this is emotion classification (0-5) or sentiment (0-1)
93- is_emotion = 'emotion' in config .get ('dataset_name' , '' ).lower ()
99+ # Check dataset type
100+ dataset_name = config .get ('dataset_name' , '' ).lower ()
101+ is_emotion = 'emotion' in dataset_name
102+ is_gsm8k = 'gsm8k' in dataset_name
94103
95104 # Sample from dataset
96105 samples = dataset .select (range (min (num_samples , len (dataset ))))
@@ -110,11 +119,14 @@ def evaluate_prompt(prompt, dataset, config, num_samples):
110119 # Call the LLM with retry logic
111120 for attempt in range (MAX_RETRIES ):
112121 try :
122+ # Adjust max_tokens based on task
123+ max_tokens = 500 if is_gsm8k else 20
124+
113125 response = test_model .chat .completions .create (
114126 model = TASK_MODEL_NAME ,
115127 messages = messages ,
116- temperature = 0.1 , # Low temperature for consistent classification
117- max_tokens = 20 # Allow slightly more tokens for emotion labels
128+ temperature = 0.1 , # Low temperature for consistent results
129+ max_tokens = max_tokens
118130 )
119131 break
120132 except Exception as e :
@@ -150,7 +162,41 @@ def evaluate_prompt(prompt, dataset, config, num_samples):
150162
151163 # Extract prediction from output
152164 try :
153- if is_emotion :
165+ if is_gsm8k :
166+ # For GSM8K, extract the numeric answer after ####
167+ # First, extract the expected answer from the ground truth
168+ expected_answer = expected .split ('####' )[- 1 ].strip ()
169+ try :
170+ expected_number = float (expected_answer .replace (',' , '' ))
171+ except :
172+ print (f"Warning: Could not parse expected answer: { expected_answer } " )
173+ total += 1
174+ continue
175+
176+ # Extract prediction from model output
177+ prediction = None
178+ if '####' in output_text :
179+ predicted_answer = output_text .split ('####' )[- 1 ].strip ()
180+ # Extract just the number, removing any extra text like $ signs
181+ import re
182+ numbers = re .findall (r'-?\$?[\d,]+\.?\d*' , predicted_answer )
183+ if numbers :
184+ try :
185+ # Remove $ and , from the number
186+ number_str = numbers [0 ].replace ('$' , '' ).replace (',' , '' )
187+ prediction = float (number_str )
188+ except :
189+ pass
190+
191+ # If we found a prediction, check if it matches
192+ if prediction is not None :
193+ # Check if answers match (with small tolerance for floats)
194+ if abs (prediction - expected_number ) < 0.001 :
195+ correct += 1
196+
197+ total += 1
198+
199+ elif is_emotion :
154200 # For emotion classification (0-5)
155201 numbers = re .findall (r'\b[0-5]\b' , output_text )
156202 if numbers :
0 commit comments