Skip to content

Commit 0ab9f8c

Browse files
authored
Merge pull request #87 from codelion/feat-update-router-training
Feat update router training
2 parents 0378fd4 + c5a4e95 commit 0ab9f8c

File tree

5 files changed

+183
-9
lines changed

5 files changed

+183
-9
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ or your own code where you want to use the results from optillm. You can use it
180180

181181
| Plugin | Slug | Description |
182182
| ----------------------- | ------------------ | ---------------------------------------------------------------------------------------------- |
183+
| Router | `router` | Uses the [optillm-bert-uncased](https://huggingface.co/codelion/optillm-bert-uncased) model to route requests to different approaches based on the user prompt |
183184
| Memory | `memory` | Implements a short term memory layer, enables you to use unbounded context length with any LLM |
184185
| Privacy | `privacy` | Anonymize PII data in request and deanonymize it back to original value in response |
185186
| Read URLs | `readurls` | Reads all URLs found in the request, fetches the content at the URL and adds it to the context |

optillm.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212
import re
1313
from concurrent.futures import ThreadPoolExecutor
1414

15-
# Import the LiteLLM wrapper
16-
from optillm.litellm_wrapper import LiteLLMWrapper
17-
1815
# Import approach modules
1916
from optillm.mcts import chat_with_mcts
2017
from optillm.bon import best_of_n_sampling
@@ -74,6 +71,8 @@ def get_config():
7471
azure_ad_token_provider=token_provider
7572
)
7673
else:
74+
# Import the LiteLLM wrapper
75+
from optillm.litellm_wrapper import LiteLLMWrapper
7776
default_client = LiteLLMWrapper()
7877
return default_client, API_KEY
7978

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import os
2+
import json
3+
import argparse
4+
import asyncio
5+
from tqdm import tqdm
6+
from datasets import load_dataset
7+
from openai import AsyncOpenAI
8+
from typing import List, Dict, Any, Tuple
9+
import random
10+
11+
# optillm approach slugs to evaluate for every prompt; "none" sends the prompt straight to the base model with no approach prefix
12+
APPROACHES = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"]
13+
14+
# Dataset configurations
15+
DATASET_CONFIGS = [
16+
("MixEval", "free_form"),
17+
("MixEval", "multiple_choice"),
18+
("MixEval_Hard", "free_form"),
19+
("MixEval_Hard", "multiple_choice")
20+
]
21+
22+
def construct_prompt(sample: Dict[str, Any], split_type: str) -> str:
    """Build the prompt text for one dataset sample.

    Args:
        sample: Dataset row. Must contain ``"prompt"``; must contain
            ``"options"`` when ``split_type`` is ``"multiple_choice"``.
            ``"context"`` is optional and may be ``None``.
        split_type: ``"multiple_choice"`` appends a 1-based numbered option
            list; any other value yields a free-form prompt.

    Returns:
        The assembled prompt string.

    Raises:
        KeyError: If a required key is missing from ``sample``.
    """
    # ``.get("context", "")`` only covers a missing key; a key that is present
    # with a null value would otherwise be rendered as the text "Context: None".
    context = sample.get("context") or ""
    question = sample["prompt"]

    if split_type != "multiple_choice":
        return f"Context: {context}\n\nQuestion: {question}\n\nProvide your answer."

    numbered = "\n".join(
        f"{number}. {option}"
        for number, option in enumerate(sample["options"], start=1)
    )
    options_text = "\nOptions:\n" + numbered
    return (
        f"Context: {context}\n\nQuestion: {question}{options_text}"
        "\n\nProvide the correct answer from the options above."
    )
33+
34+
def is_correct_response(response: str, targets: List[str]) -> bool:
    """Return True when the response equals one of the target answers.

    Comparison is an exact string match after stripping surrounding
    whitespace and lower-casing both sides.

    Args:
        response: Model output text. ``None`` (a completion with no text
            content) is treated as an incorrect answer instead of raising.
        targets: Acceptable answers.

    Returns:
        True if any normalized target equals the normalized response.
    """
    if response is None:
        return False

    normalized = response.strip().lower()
    return any(target.strip().lower() == normalized for target in targets)
38+
39+
async def generate_response(prompt: str, approach: str) -> Dict[str, Any]:
    """Request one chat completion for ``prompt`` using ``approach``.

    For ``approach == "none"`` a default-configured ``AsyncOpenAI`` client is
    used with model ``gpt-4o-mini``. For any other approach the request goes
    to the OpenAI-compatible endpoint at ``http://localhost:8000/v1``
    (presumably a locally running optillm proxy - confirm) with the model
    name ``"<approach>-gpt-4o-mini"``.

    Args:
        prompt: User message content to send.
        approach: Approach slug, or ``"none"`` for the plain model.

    Returns:
        A dict with ``"content"`` (text of the first choice) and ``"tokens"``
        (completion token count reported in the response usage).
    """
    if approach == "none":
        client = AsyncOpenAI()
        model = "gpt-4o-mini"
    else:
        client = AsyncOpenAI(api_key="none", base_url="http://localhost:8000/v1")
        model = f"{approach}-gpt-4o-mini"

    # A new client is built per call; closing it via the context manager
    # releases its HTTP connection pool instead of leaking one per request.
    async with client:
        response = await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
        )

    return {
        "content": response.choices[0].message.content,
        "tokens": response.usage.completion_tokens,
    }
61+
62+
def rank_responses(responses: List[Dict[str, Any]], targets: List[str]) -> List[int]:
    """Order response indices from best to worst.

    Correct responses come before incorrect ones; within each group the
    response with fewer tokens comes first. Ties keep their original order.

    Args:
        responses: Dicts carrying at least ``"content"`` and ``"tokens"``.
        targets: Acceptable answers used to judge correctness.

    Returns:
        Indices into ``responses`` in ranked order.
    """
    def sort_key(position: int) -> Tuple[int, Any]:
        entry = responses[position]
        correct = is_correct_response(entry["content"], targets)
        # 0 sorts ahead of 1, so correct answers lead; tokens break ties.
        return (0 if correct else 1, entry["tokens"])

    return sorted(range(len(responses)), key=sort_key)
75+
76+
async def process_sample(sample: Dict[str, Any], split_type: str) -> Dict[str, Any]:
    """Run every approach on one sample and rank the outcomes.

    Args:
        sample: Dataset row; its ``"target"`` entry supplies the accepted
            answers used for ranking.
        split_type: Passed through to ``construct_prompt``.

    Returns:
        A dict with the ``"prompt"`` that was sent and ``"results"``, one
        entry per approach holding the approach name, the generated
        response fields and its zero-based ``"rank"``.
    """
    prompt = construct_prompt(sample, split_type)

    # Approaches are queried one after another, in APPROACHES order.
    results = [
        {"approach": approach, **(await generate_response(prompt, approach))}
        for approach in APPROACHES
    ]

    # Position in the ranked index list becomes the rank of that result.
    ordering = rank_responses(results, sample["target"])
    for position, result_index in enumerate(ordering):
        results[result_index]["rank"] = position

    return {"prompt": prompt, "results": results}
97+
98+
async def generate_dataset(num_samples: int, output_file: str):
    """Generate the ground-truth dataset and write it as JSON Lines.

    The sample budget is divided evenly across ``DATASET_CONFIGS`` (at least
    one sample per configuration). Each processed sample is written as one
    JSON object per line.

    Args:
        num_samples: Total number of samples to process across all
            configurations.
        output_file: Path of the JSONL file to create or overwrite.
    """
    # Loop-invariant: the per-configuration share of the sample budget.
    samples_per_config = max(1, num_samples // len(DATASET_CONFIGS))

    with open(output_file, "w", encoding="utf-8") as f:
        for config, split_type in DATASET_CONFIGS:
            print(f"Processing {config} - {split_type}")
            dataset = load_dataset("MixEval/MixEval", config, split=split_type)

            # Never ask for more rows than the split holds; selecting
            # out-of-range indices would abort the whole run.
            take = min(samples_per_config, len(dataset))

            for sample in tqdm(dataset.select(range(take)),
                               total=take,
                               desc=f"{config}-{split_type}"):
                try:
                    result = await process_sample(sample, split_type)
                    f.write(json.dumps(result) + "\n")
                except Exception as e:
                    # Deliberate best-effort: report and continue so a single
                    # failing sample does not discard the rest of the run.
                    print(f"Error processing sample: {str(e)}")
116+
117+
def main():
    """Parse command-line options and run dataset generation."""
    cli = argparse.ArgumentParser(description="Generate OptILM Ground Truth dataset")
    cli.add_argument(
        "--num_samples",
        type=int,
        default=100,
        help="Total number of samples to process (divided among configurations)",
    )
    cli.add_argument(
        "--output_file",
        type=str,
        default="optillm_ground_truth_dataset.jsonl",
        help="Output file path",
    )
    options = cli.parse_args()
    num_samples, output_file = options.num_samples, options.output_file

    # Drive the async pipeline to completion before reporting success.
    asyncio.run(generate_dataset(num_samples, output_file))
    print(f"Dataset generated and saved to {output_file}")
128+
129+
if __name__ == "__main__":
130+
main()

scripts/train_optillm_classifier.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __getitem__(self, idx):
5757
}
5858

5959
def load_and_preprocess_data(tokenizer):
60-
dataset = load_dataset('json', data_files='optillm_dataset.jsonl')
60+
dataset = load_dataset('json', data_files='optillm_combined_dataset.jsonl')
6161

6262
data_items = []
6363

@@ -290,11 +290,54 @@ def main(args):
290290
best_model.eval()
291291

292292
test_prompts = [
293+
# Linear Programming (likely MCTS or Z3)
293294
"Maximize x + y subject to: x + 2y <= 10, x >= 0, y >= 0",
295+
# Graph Theory (likely MCTS or RTO)
294296
"Find the shortest path between nodes A and B in the given graph",
297+
# Recursive Problem (likely MOA or COT)
295298
"Solve the Tower of Hanoi problem with 4 disks",
299+
# Number Theory (likely NONE or Z3)
296300
"Determine if the given number is prime",
297-
"Find all possible combinations of coins that sum up to $1"
301+
# Combinatorics (likely MCTS or BON)
302+
"Find all possible combinations of coins that sum up to $1",
303+
# Symbolic Mathematics (likely Z3 or LEAP)
304+
"Solve the equation: 2x^3 - 5x^2 + 3x - 7 = 0",
305+
# Natural Language Processing (likely PVG or SELF_CONSISTENCY)
306+
"Summarize the main points of the given article in three sentences",
307+
# Computer Vision (likely RSTAR or PVG)
308+
"Describe the contents of the image, including any text present",
309+
# Game Theory (likely MCTS or BON)
310+
"Find the Nash equilibrium for the prisoner's dilemma game",
311+
# Constraint Satisfaction (likely Z3 or PLANSEARCH)
312+
"Solve the Sudoku puzzle given the following initial configuration",
313+
# Optimization (likely MCTS or RSTAR)
314+
"Find the optimal route for a salesperson visiting 10 cities",
315+
# Logical Reasoning (likely COT_REFLECTION or SELF_CONSISTENCY)
316+
"If all A are B, and some B are C, what can we conclude about A and C?",
317+
# Time Series Analysis (likely RSTAR or PVG)
318+
"Predict the stock price for the next week given the past year's data",
319+
# Robotics (likely MCTS or RTO)
320+
"Plan a path for a robot to navigate through a room with obstacles",
321+
# Natural Language Understanding (likely PVG or LEAP)
322+
"Identify the sentiment and main topics in the following customer review",
323+
# Theorem Proving (likely Z3 or COT_REFLECTION)
324+
"Prove that the square root of 2 is irrational",
325+
# Reinforcement Learning (likely MCTS or RSTAR)
326+
"Design a policy for an agent to maximize its score in a given game environment",
327+
# Information Retrieval (likely PVG or SELF_CONSISTENCY)
328+
"Find the most relevant documents in the corpus for the given query",
329+
# Cryptography (likely Z3 or LEAP)
330+
"Decrypt the following message encrypted with a simple substitution cipher",
331+
# Quantum Computing (likely NONE or Z3)
332+
"Simulate a quantum circuit with 3 qubits and measure the output",
333+
# Computer Graphics (likely RSTAR or PVG)
334+
"Generate a 3D model of a house based on the given floor plan",
335+
# Bioinformatics (likely Z3 or LEAP)
336+
"Find potential binding sites for a given protein sequence in a DNA strand",
337+
# Automated Reasoning (likely COT_REFLECTION or Z3)
338+
"Given a set of logical statements, determine if the conclusion follows",
339+
# Natural Language Generation (likely PVG or SELF_CONSISTENCY)
340+
"Write a short story in the style of Edgar Allan Poe about a haunted lighthouse"
298341
]
299342

300343
effort_levels = [0.0, 0.2, 0.5, 0.8, 1.0]
@@ -310,13 +353,14 @@ def main(args):
310353
parser = argparse.ArgumentParser(description="Train OptILM classifier")
311354
parser.add_argument("--model_name", type=str, default="google-bert/bert-large-uncased", help="Pretrained model name")
312355
parser.add_argument("--batch_size", type=int, default=4, help="Batch size for training")
313-
parser.add_argument("--learning_rate", type=float, default=1e-6, help="Learning rate")
314-
parser.add_argument("--num_epochs", type=int, default=10, help="Maximum number of training epochs")
356+
parser.add_argument("--learning_rate", type=float, default=5e-7, help="Learning rate")
357+
parser.add_argument("--num_epochs", type=int, default=20, help="Maximum number of training epochs")
315358
parser.add_argument("--push_to_hub", action="store_true", help="Push model to Hugging Face Hub")
316359
parser.add_argument("--hub_model_id", type=str, help="Model ID for Hugging Face Hub")
317360
parser.add_argument("--k_folds", type=int, default=5, help="Number of folds for cross-validation")
318361
parser.add_argument("--patience", type=int, default=3, help="Number of epochs to wait for improvement before early stopping")
319362
parser.add_argument("--clip_value", type=float, default=1.0, help="Gradient clipping value")
320363

321364
args = parser.parse_args()
322-
main(args)
365+
main(args)
366+

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="optillm",
5-
version="0.0.8",
5+
version="0.0.9",
66
packages=find_packages(),
77
py_modules=['optillm'],
88
package_data={

0 commit comments

Comments
 (0)