import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_model

from optillm.mcts import chat_with_mcts
from optillm.bon import best_of_n_sampling
from optillm.moa import mixture_of_agents
from optillm.rto import round_trip_optimization
from optillm.self_consistency import advanced_self_consistency_approach
from optillm.pvg import inference_time_pv_game
from optillm.z3_solver import Z3SymPySolverSystem
from optillm.rstar import RStar
from optillm.cot_reflection import cot_reflection
from optillm.plansearch import plansearch
from optillm.leap import leap
from optillm.reread import re2_approach

SLUG = "router"

# Constants
MAX_LENGTH = 512
APPROACHES = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"]
MODEL_NAME = "codelion/optillm-bert-uncased"

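# Router head: a BERT-style encoder whose [CLS] embedding is concatenated with a
# small MLP encoding of a scalar "effort" value; a linear layer then scores each
# approach listed in APPROACHES.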
class OptILMClassifier(nn.Module):
    def __init__(self, base_model, num_labels):
        super().__init__()
        self.base_model = base_model
        self.effort_encoder = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU()
        )
        self.classifier = nn.Linear(base_model.config.hidden_size + 64, num_labels)

    def forward(self, input_ids, attention_mask, effort):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0]  # Shape: (batch_size, hidden_size)
        effort_encoded = self.effort_encoder(effort.unsqueeze(1))  # Shape: (batch_size, 64)
        combined_input = torch.cat((pooled_output, effort_encoded), dim=1)
        logits = self.classifier(combined_input)
        return logits

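# Builds the router: loads the bert-large-uncased backbone, wraps it in
# OptILMClassifier, and fetches the fine-tuned weights (model.safetensors) from
# the MODEL_NAME repo on the Hugging Face Hub. Prefers MPS, then CUDA, then CPU.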
def load_optillm_model():
    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
    # Load the base model
    base_model = AutoModel.from_pretrained("google-bert/bert-large-uncased")
    # Create the OptILMClassifier
    model = OptILMClassifier(base_model, num_labels=len(APPROACHES))
    model.to(device)
    # Download the safetensors file
    safetensors_path = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
    # Load the state dict from the safetensors file
    load_model(model, safetensors_path)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return model, tokenizer, device

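# Joins the system prompt and user query into one sequence and tokenizes it,
# padding/truncating to MAX_LENGTH tokens.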
def preprocess_input(tokenizer, system_prompt, initial_query):
    combined_input = f"{system_prompt}\n\nUser: {initial_query}"
    encoding = tokenizer.encode_plus(
        combined_input,
        add_special_tokens=True,
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )
    return encoding['input_ids'], encoding['attention_mask']

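# Scores the tokenized prompt with the router and returns the most likely
# approach together with its softmax probability. `effort` is the scalar fed to
# the effort encoder; callers that do not override it get the 0.7 default.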
def predict_approach(model, input_ids, attention_mask, device, effort=0.7):
    model.eval()
    with torch.no_grad():
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        effort_tensor = torch.tensor([effort], dtype=torch.float).to(device)

        logits = model(input_ids, attention_mask=attention_mask, effort=effort_tensor)
        probabilities = F.softmax(logits, dim=1)
        predicted_approach_index = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0][predicted_approach_index].item()

    return APPROACHES[predicted_approach_index], confidence

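# Plugin entry point: classifies the incoming request, dispatches to the matching
# optillm approach (or calls the model directly when the router predicts "none"),
# and falls back to a plain chat completion if anything goes wrong. Returns the
# response text and the completion token count.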
def run(system_prompt, initial_query, client, model, **kwargs):
    try:
        # Load the trained model
        router_model, tokenizer, device = load_optillm_model()

        # Preprocess the input
        input_ids, attention_mask = preprocess_input(tokenizer, system_prompt, initial_query)

        # Predict the best approach
        predicted_approach, _ = predict_approach(router_model, input_ids, attention_mask, device)

        print(f"Router predicted approach: {predicted_approach}")

        # Route to the appropriate approach or use the model directly
        if predicted_approach == "none":
            # Use the model directly without routing
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": initial_query}
                ]
            )
            return response.choices[0].message.content, response.usage.completion_tokens
        elif predicted_approach == "mcts":
            return chat_with_mcts(system_prompt, initial_query, client, model, **kwargs)
        elif predicted_approach == "bon":
            return best_of_n_sampling(system_prompt, initial_query, client, model, **kwargs)
        elif predicted_approach == "moa":
            return mixture_of_agents(system_prompt, initial_query, client, model)
        elif predicted_approach == "rto":
            return round_trip_optimization(system_prompt, initial_query, client, model)
        elif predicted_approach == "z3":
            z3_solver = Z3SymPySolverSystem(system_prompt, client, model)
            return z3_solver.process_query(initial_query)
        elif predicted_approach == "self_consistency":
            return advanced_self_consistency_approach(system_prompt, initial_query, client, model)
        elif predicted_approach == "pvg":
            return inference_time_pv_game(system_prompt, initial_query, client, model)
        elif predicted_approach == "rstar":
            rstar = RStar(system_prompt, client, model, **kwargs)
            return rstar.solve(initial_query)
        elif predicted_approach == "cot_reflection":
            return cot_reflection(system_prompt, initial_query, client, model, **kwargs)
        elif predicted_approach == "plansearch":
            return plansearch(system_prompt, initial_query, client, model, **kwargs)
        elif predicted_approach == "leap":
            return leap(system_prompt, initial_query, client, model)
        elif predicted_approach == "re2":
            return re2_approach(system_prompt, initial_query, client, model, **kwargs)
        else:
            raise ValueError(f"Unknown approach: {predicted_approach}")

    except Exception as e:
        # Log the error and fall back to using the model directly
        print(f"Error in router plugin: {str(e)}. Falling back to direct model usage.")
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": initial_query}
            ]
        )
        return response.choices[0].message.content, response.usage.completion_tokens
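
# Illustrative usage sketch, not part of the plugin interface: in normal operation
# the optillm proxy is expected to supply `client` and `model` when it invokes run().
# The base URL, API key, and model name below are placeholders/assumptions, not
# values taken from this repository.
if __name__ == "__main__":
    from openai import OpenAI

    client = OpenAI(api_key="sk-placeholder", base_url="http://localhost:8000/v1")  # assumed local endpoint
    answer, completion_tokens = run(
        system_prompt="You are a helpful assistant.",
        initial_query="Plan a three-step approach to a scheduling problem.",
        client=client,
        model="gpt-4o-mini",  # placeholder model name
    )
    print(f"Answer: {answer}\nCompletion tokens: {completion_tokens}")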