Commit 41821aa

Merge pull request #62 from codelion/feat-router
Feat router
2 parents c0cb613 + 9afdadb commit 41821aa

File tree

3 files changed: +350 -101 lines changed


optillm.py

Lines changed: 5 additions & 1 deletion
@@ -49,7 +49,11 @@ def get_config():
     # OpenAI, Azure, or LiteLLM API configuration
     if os.environ.get("OPENAI_API_KEY"):
         API_KEY = os.environ.get("OPENAI_API_KEY")
-        default_client = OpenAI(api_key=API_KEY)
+        base_url = server_config['base_url']
+        if base_url != "":
+            default_client = OpenAI(api_key=API_KEY, base_url=base_url)
+        else:
+            default_client = OpenAI(api_key=API_KEY)
     elif os.environ.get("AZURE_OPENAI_API_KEY"):
         API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
         API_VERSION = os.environ.get("AZURE_API_VERSION")
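In effect, a non-empty base_url in server_config points the default client at any OpenAI-compatible endpoint instead of the library's default (https://api.openai.com/v1). A minimal sketch of the resulting behavior; the localhost URL and the inline server_config dict are illustrative stand-ins, not values from the commit:

import os
from openai import OpenAI

API_KEY = os.environ.get("OPENAI_API_KEY")
# Illustrative stand-in for optillm's parsed server configuration.
server_config = {"base_url": "http://localhost:8080/v1"}

base_url = server_config["base_url"]
if base_url != "":
    # Route requests to a local OpenAI-compatible server.
    default_client = OpenAI(api_key=API_KEY, base_url=base_url)
else:
    # Empty string keeps the library default endpoint.
    default_client = OpenAI(api_key=API_KEY)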

optillm/plugins/router_plugin.py

Lines changed: 154 additions & 0 deletions
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_model
from optillm.mcts import chat_with_mcts
from optillm.bon import best_of_n_sampling
from optillm.moa import mixture_of_agents
from optillm.rto import round_trip_optimization
from optillm.self_consistency import advanced_self_consistency_approach
from optillm.pvg import inference_time_pv_game
from optillm.z3_solver import Z3SymPySolverSystem
from optillm.rstar import RStar
from optillm.cot_reflection import cot_reflection
from optillm.plansearch import plansearch
from optillm.leap import leap
from optillm.reread import re2_approach

SLUG = "router"

# Constants
MAX_LENGTH = 512
APPROACHES = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"]
MODEL_NAME = "codelion/optillm-bert-uncased"

class OptILMClassifier(nn.Module):
    def __init__(self, base_model, num_labels):
        super().__init__()
        self.base_model = base_model
        # Encode the scalar effort value into a 64-dimensional feature vector.
        self.effort_encoder = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU()
        )
        # Classify over the concatenation of the pooled embedding and the effort features.
        self.classifier = nn.Linear(base_model.config.hidden_size + 64, num_labels)

    def forward(self, input_ids, attention_mask, effort):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0]  # [CLS] token, shape: (batch_size, hidden_size)
        effort_encoded = self.effort_encoder(effort.unsqueeze(1))  # Shape: (batch_size, 64)
        combined_input = torch.cat((pooled_output, effort_encoded), dim=1)
        logits = self.classifier(combined_input)
        return logits

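# Shape sketch (illustrative): with bert-large-uncased, hidden_size is 1024,
# so forward() maps input_ids/attention_mask of shape (batch_size, seq_len)
# plus an effort tensor of shape (batch_size,) to logits of shape
# (batch_size, len(APPROACHES)) == (batch_size, 13).
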
def load_optillm_model():
    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
    # Load the base model
    base_model = AutoModel.from_pretrained("google-bert/bert-large-uncased")
    # Create the OptILMClassifier
    model = OptILMClassifier(base_model, num_labels=len(APPROACHES))
    model.to(device)
    # Download the safetensors file
    safetensors_path = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
    # Load the state dict from the safetensors file
    load_model(model, safetensors_path)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return model, tokenizer, device

def preprocess_input(tokenizer, system_prompt, initial_query):
    combined_input = f"{system_prompt}\n\nUser: {initial_query}"
    encoding = tokenizer.encode_plus(
        combined_input,
        add_special_tokens=True,
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )
    return encoding['input_ids'], encoding['attention_mask']

def predict_approach(model, input_ids, attention_mask, device, effort=0.7):
    model.eval()
    with torch.no_grad():
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        effort_tensor = torch.tensor([effort], dtype=torch.float).to(device)

        logits = model(input_ids, attention_mask=attention_mask, effort=effort_tensor)
        probabilities = F.softmax(logits, dim=1)
        predicted_approach_index = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0][predicted_approach_index].item()

    return APPROACHES[predicted_approach_index], confidence

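# Illustrative end-to-end call (names as defined above; the predicted label
# depends on the query and the trained weights):
#   model, tokenizer, device = load_optillm_model()
#   ids, mask = preprocess_input(tokenizer, system_prompt, initial_query)
#   approach, confidence = predict_approach(model, ids, mask, device, effort=0.7)
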
def run(system_prompt, initial_query, client, model, **kwargs):
    try:
        # Load the trained model
        router_model, tokenizer, device = load_optillm_model()

        # Preprocess the input
        input_ids, attention_mask = preprocess_input(tokenizer, system_prompt, initial_query)

        # Predict the best approach
        predicted_approach, _ = predict_approach(router_model, input_ids, attention_mask, device)

        print(f"Router predicted approach: {predicted_approach}")

        # Route to the appropriate approach or use the model directly
        if predicted_approach == "none":
            # Use the model directly without routing
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": initial_query}
                ]
            )
            return response.choices[0].message.content, response.usage.completion_tokens
        elif predicted_approach == "mcts":
            return chat_with_mcts(system_prompt, initial_query, client, model, **kwargs)
        elif predicted_approach == "bon":
            return best_of_n_sampling(system_prompt, initial_query, client, model, **kwargs)
        elif predicted_approach == "moa":
            return mixture_of_agents(system_prompt, initial_query, client, model)
        elif predicted_approach == "rto":
            return round_trip_optimization(system_prompt, initial_query, client, model)
        elif predicted_approach == "z3":
            z3_solver = Z3SymPySolverSystem(system_prompt, client, model)
            return z3_solver.process_query(initial_query)
        elif predicted_approach == "self_consistency":
            return advanced_self_consistency_approach(system_prompt, initial_query, client, model)
        elif predicted_approach == "pvg":
            return inference_time_pv_game(system_prompt, initial_query, client, model)
        elif predicted_approach == "rstar":
            rstar = RStar(system_prompt, client, model, **kwargs)
            return rstar.solve(initial_query)
        elif predicted_approach == "cot_reflection":
            return cot_reflection(system_prompt, initial_query, client, model, **kwargs)
        elif predicted_approach == "plansearch":
            return plansearch(system_prompt, initial_query, client, model, **kwargs)
        elif predicted_approach == "leap":
            return leap(system_prompt, initial_query, client, model)
        elif predicted_approach == "re2":
            return re2_approach(system_prompt, initial_query, client, model, **kwargs)
        else:
            raise ValueError(f"Unknown approach: {predicted_approach}")

    except Exception as e:
        # Log the error and fall back to using the model directly
        print(f"Error in router plugin: {str(e)}. Falling back to direct model usage.")
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": initial_query}
            ]
        )
        return response.choices[0].message.content, response.usage.completion_tokens
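
For context, a hedged sketch of how the plugin is reached in practice: optillm selects approaches by prefixing the model name, so assuming the proxy applies the same convention to this plugin's SLUG and listens at its default local address, a client call would look like the following (the address, port, API key placeholder, and underlying model name are assumptions, not taken from the commit):

from openai import OpenAI

# Assumed default address of a locally running optillm proxy.
client = OpenAI(api_key="sk-...", base_url="http://localhost:8000/v1")

# The "router-" prefix (the plugin's SLUG) asks the classifier to pick the
# approach; the remainder names the underlying model.
response = client.chat.completions.create(
    model="router-gpt-4o-mini",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Plan a three-step refactor for a legacy module."},
    ],
)
print(response.choices[0].message.content)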
