#!/usr/bin/env python3
import os
import json
import random
from typing import List, Dict, Any
import datasets
from datasets import Dataset, DatasetDict
from tqdm import tqdm
import numpy as np
from huggingface_hub import HfApi

# Configure random seed for reproducibility
random.seed(42)
np.random.seed(42)

# Configuration
DATASET_NAME = "codelion/optillmbench"
NUM_SAMPLES = 100  # Total samples in the benchmark
SPLIT_RATIO = {"train": 0.8, "test": 0.2}  # 80-20 split
SOURCES = {
    "gsm8k": {
        "name": "gsm8k",
        "subset": "main",
        "samples": 25,
        "field_map": {
            "question": "question",
            "answer": "answer"
        }
    },
    "boolq": {
        "name": "boolq",
        "subset": None,
        "samples": 25,
        "field_map": {
            "question": "question",
            "passage": "passage",
            "answer": "answer"
        }
    },
    "mmlu_math": {
        "name": "cais/mmlu",
        "subset": "high_school_mathematics",  # or "college_mathematics"
        "samples": 25,
        "field_map": {
            "question": "question",
            "choices": "choices",
            "answer": "answer"
        }
    },
    "aqua_rat": {
        "name": "aqua_rat",
        "subset": None,
        "samples": 25,
        "field_map": {
            "question": "question",
            "answer": "correct"
        }
    }
}
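
# Note: each field_map maps this script's canonical keys ("question", "answer",
# "passage", "choices") to the column names used by the corresponding source
# dataset. For example, aqua_rat stores its gold option letter under "correct"
# rather than "answer".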

def select_challenging_examples(
    dataset: datasets.Dataset,
    category: str,
    num_samples: int,
    field_map: Dict[str, str]
) -> List[Dict[str, Any]]:
    """Select challenging examples from the dataset"""
    examples = []

    # Get all available examples
    all_examples = dataset["train"] if "train" in dataset else dataset["validation"]

    # Shuffle to randomize selection
    shuffled_indices = list(range(len(all_examples)))
    random.shuffle(shuffled_indices)

    # Select examples
    for idx in shuffled_indices:
        example = all_examples[idx]

        try:
            if category == "gsm8k":
                question = str(example[field_map["question"]])
                answer = str(example[field_map["answer"]])
                # Select only multi-step problems
                if answer.count("=") < 3:
                    continue

            elif category == "boolq":
                passage = str(example[field_map["passage"]])
                q = str(example[field_map["question"]])
                question = f"Context: {passage}\nQuestion: {q}"
                answer = "Yes" if example[field_map["answer"]] else "No"

            elif category == "mmlu_math":
                question = str(example[field_map["question"]])
                choices = example[field_map["choices"]]
                answer_index = int(example[field_map["answer"]])  # Convert answer to integer

                # Ensure answer index is within bounds
                if 0 <= answer_index < len(choices):
                    answer = choices[answer_index]
                else:
                    print(f"Warning: Answer index '{answer_index}' is out of range for choices: {choices}")
                    continue  # Skip this example if answer index is invalid

                # Format choices
                choices_text = "\n".join([f"{i}. {choice}" for i, choice in enumerate(choices)])
                question = f"{question}\nChoices:\n{choices_text}"

            elif category == "aqua_rat":
                question = str(example[field_map["question"]])
                answer = str(example[field_map["answer"]])
                # Ensure non-trivial multiple-choice math problems
                if len(question.split()) < 12:
                    continue

            # General filtering
            if len(question.split()) < 10:  # Ensure substantial questions
                continue

            examples.append(format_question(category, question, answer))

            if len(examples) >= num_samples:
                break

        except Exception as e:
            print(f"Error processing example from {category}: {str(e)}")
            continue

    return examples

def clean_text(text: str) -> str:
    """Collapse all whitespace, including newlines, into single spaces"""
    return " ".join(text.replace("\r", "\n").split())
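
# For instance, clean_text("Context: some passage\nQuestion: is water wet?")
# returns "Context: some passage Question: is water wet?" -- the newlines added
# by the BoolQ and MMLU formatting above are flattened to spaces.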

def format_question(category: str, question: str, answer: str) -> Dict[str, Any]:
    """Format a question for the benchmark dataset"""
    # Basic sanity checks
    if not question or not answer:
        raise ValueError(f"Empty question or answer in {category}")

    return {
        "id": f"{category}_{random.getrandbits(32):08x}",
        "category": category,
        "question": clean_text(question),
        "answer": clean_text(answer),
        "metadata": {
            "source": SOURCES[category]["name"],
            "type": category,
            "difficulty": "challenging"  # All examples are chosen to be challenging
        }
    }
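
# A formatted record looks like the following (field values are illustrative):
# {
#     "id": "gsm8k_1a2b3c4d",
#     "category": "gsm8k",
#     "question": "Natalia sold clips to 48 of her friends ...",
#     "answer": "... #### 72",
#     "metadata": {"source": "gsm8k", "type": "gsm8k", "difficulty": "challenging"}
# }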

def load_source_dataset(config: Dict[str, Any]) -> datasets.Dataset:
    """Load a source dataset with error handling"""
    try:
        dataset = datasets.load_dataset(
            config["name"],
            config.get("subset")
        )
        return dataset
    except Exception as e:
        print(f"Error loading dataset {config['name']}: {str(e)}")
        return None
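
# Passing subset=None is fine here: datasets.load_dataset(name, None) loads the
# dataset's default configuration, which is what BoolQ and AQuA-RAT need.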

def create_benchmark_dataset() -> DatasetDict:
    """Create the complete benchmark dataset"""
    all_examples = []

    # Process each source dataset
    for category, config in tqdm(SOURCES.items(), desc="Processing datasets"):
        print(f"\nProcessing {category} dataset...")

        # Load dataset
        dataset = load_source_dataset(config)
        if not dataset:
            continue

        # Select examples
        try:
            examples = select_challenging_examples(
                dataset,
                category,
                config["samples"],
                config["field_map"]
            )
            print(f"Selected {len(examples)} examples from {category}")
            all_examples.extend(examples)
        except Exception as e:
            print(f"Error selecting examples from {category}: {str(e)}")
            continue

    # Shuffle final dataset
    random.shuffle(all_examples)

    # Create train/test splits
    num_train = int(len(all_examples) * SPLIT_RATIO["train"])
    train_examples = all_examples[:num_train]
    test_examples = all_examples[num_train:]

    # Convert to HuggingFace Dataset
    dataset_dict = DatasetDict({
        "train": Dataset.from_list(train_examples),
        "test": Dataset.from_list(test_examples)
    })

    return dataset_dict

def push_to_hub(dataset: DatasetDict, repo_id: str):
    """Push the dataset to HuggingFace Hub"""
    try:
        # Create README content
        readme_content = f"""# OptiLLMBench Dataset

A benchmark dataset for evaluating test-time optimization and scaling capabilities of language models.

## Dataset Description

OptiLLMBench contains {NUM_SAMPLES} carefully selected challenging problems across multiple domains:
- Math word problems (from GSM8K)
- Reading comprehension (from BoolQ)
- Multiple choice math reasoning (from MMLU high school mathematics)
- Algebraic word problems (from AQuA-RAT)

Each example is chosen to benefit from test-time optimization techniques like:
- Increased context length
- Chain-of-thought reasoning
- Self-consistency
- Multiple solution attempts
- And other scaling approaches

## Usage

```python
from datasets import load_dataset

dataset = load_dataset("codelion/optillmbench")

# Access examples
for example in dataset["train"]:
    print(f"Category: {{example['category']}}")
    print(f"Question: {{example['question']}}")
    print(f"Answer: {{example['answer']}}")
    print(f"Metadata: {{example['metadata']}}")
```

## Citation

If you use this dataset in your research, please cite:

```bibtex
@software{{optillm,
  title = {{Optillm: Optimizing inference proxy for LLMs}},
  author = {{Asankhaya Sharma}},
  year = {{2024}},
  publisher = {{GitHub}},
  url = {{https://github.com/codelion/optillm}}
}}
```
"""

        # Push to hub
        dataset.push_to_hub(
            repo_id,
            private=False,
            embed_external_files=True
        )

        # Update README
        api = HfApi()
        api.upload_file(
            path_or_fileobj=readme_content.encode(),
            path_in_repo="README.md",
            repo_id=repo_id,
            repo_type="dataset"
        )

        print(f"Successfully pushed dataset to {repo_id}")

    except Exception as e:
        print(f"Error pushing to hub: {str(e)}")

def main():
    """Main execution function"""
    print("Starting OptiLLMBench dataset generation...")

    # Create dataset
    dataset = create_benchmark_dataset()

    # Print statistics
    print("\nDataset Statistics:")
    for split in dataset:
        print(f"\n{split} split:")
        print(f"Number of examples: {len(dataset[split])}")
        categories = dataset[split].unique("category")
        for category in categories:
            count = len([ex for ex in dataset[split] if ex["category"] == category])
            print(f"- {category}: {count} examples")

    # Push to HuggingFace Hub
    print("\nPushing dataset to HuggingFace Hub...")
    push_to_hub(dataset, DATASET_NAME)

if __name__ == "__main__":
    main()
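
# Quick sanity check after a successful push (run separately; assumes you are
# authenticated, e.g. via `huggingface-cli login`):
#
#   from datasets import load_dataset
#   ds = load_dataset("codelion/optillmbench")
#   assert set(ds.keys()) == {"train", "test"}
#   print(ds["train"][0])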