Skip to content

Commit 7a23694

Browse files
authored
Merge pull request #128 from codelion/feat-modernbert-router
Add new modernbert based router in plugin
2 parents 4c147e3 + 24329d4 commit 7a23694

File tree

8 files changed

+60
-25
lines changed

8 files changed

+60
-25
lines changed

optillm.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -624,12 +624,11 @@ def health():
624624

625625
def parse_args():
626626
parser = argparse.ArgumentParser(description="Run LLM inference with various approaches.")
627-
628-
# Add version argument using importlib.metadata
627+
629628
try:
630-
package_version = version('optillm')
631-
except Exception:
632-
package_version = "unknown" # Fallback if package is not installed
629+
from optillm import __version__ as package_version
630+
except ImportError:
631+
package_version = "unknown"
633632

634633
parser.add_argument('--version', action='version',
635634
version=f'%(prog)s {package_version}',

optillm/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
from importlib import util
22
import os
3+
import re
4+
5+
def get_version_from_setup():
    """Return the optillm version string, or "unknown" if it cannot be found.

    The version is looked up in two places, in order:

    1. The ``version=...`` keyword argument in the ``setup.py`` located one
       directory above this package (the layout of a source checkout).
    2. The installed distribution metadata for ``optillm``, used when no
       ``setup.py`` is readable next to the package or it holds no version.

    Returns:
        str: The version, or ``"unknown"`` when neither source yields one.
    """
    setup_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'setup.py')
    try:
        with open(setup_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except (OSError, UnicodeDecodeError):
        # Best effort: a missing or unreadable setup.py is not an error here.
        content = ""

    # ``\b`` stops keywords that merely end in "version" (for example
    # ``python_version=``) from matching; the optional whitespace accepts
    # ``version = "x"`` as well as ``version="x"``.
    version_match = re.search(r'\bversion\s*=\s*["\']([^"\']+)["\']', content)
    if version_match:
        return version_match.group(1)

    try:
        from importlib.metadata import version
        return version('optillm')
    except ImportError:
        # Covers both a missing importlib.metadata module and
        # PackageNotFoundError, which is an ImportError subclass.
        return "unknown"
316

417
# Get the path to the root optillm.py
518
spec = util.spec_from_file_location(
@@ -34,7 +47,7 @@
3447
generate_streaming_response = module.generate_streaming_response
3548

3649
# Version information
37-
__version__ = "0.0.8" # Match with setup.py
50+
__version__ = get_version_from_setup()
3851

3952
# List of exported symbols
4053
__all__ = [

optillm/plugins/readurls_plugin.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,28 @@
11
import re
22
from typing import Tuple, List
33
import requests
4+
import os
45
from bs4 import BeautifulSoup
56
from urllib.parse import urlparse
67

78
SLUG = "readurls"
89

10+
def get_version():
    """Return the optillm version used in this plugin's User-Agent header.

    Looks for a ``version=...`` keyword argument in the ``setup.py`` two
    directories above this file's directory, then falls back to the installed
    ``optillm`` distribution metadata.

    Returns:
        str: The version, or ``"unknown"`` when neither source yields one.
    """
    # Resolve setup.py relative to this file: <root>/optillm/plugins/<this file>.
    current_dir = os.path.dirname(__file__)
    package_root = os.path.dirname(os.path.dirname(current_dir))
    setup_path = os.path.join(package_root, 'setup.py')

    try:
        with open(setup_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except (OSError, UnicodeDecodeError):
        # Best effort: a missing or unreadable setup.py is not an error here.
        content = ""

    # ``\b`` prevents a match inside longer keywords such as
    # ``python_version=``; whitespace around ``=`` is tolerated.
    version_match = re.search(r'\bversion\s*=\s*["\']([^"\']+)["\']', content)
    if version_match:
        return version_match.group(1)

    try:
        from importlib.metadata import version
        return version('optillm')
    except ImportError:
        # PackageNotFoundError subclasses ImportError, so this also handles
        # the case where optillm is not an installed distribution.
        return "unknown"
25+
926
def extract_urls(text: str) -> List[str]:
1027
# Updated regex pattern to be more precise
1128
url_pattern = re.compile(r'https?://[^\s\'"]+')
@@ -24,8 +41,9 @@ def extract_urls(text: str) -> List[str]:
2441

2542
def fetch_webpage_content(url: str, max_length: int = 100000) -> str:
2643
try:
44+
version = get_version()
2745
headers = {
28-
'User-Agent': 'optillm/0.0.21 (https://github.com/codelion/optillm)'
46+
'User-Agent': f'optillm/{version} (https://github.com/codelion/optillm)'
2947
}
3048

3149
response = requests.get(url, headers=headers, timeout=10)

optillm/plugins/router_plugin.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
SLUG = "router"
2323

2424
# Constants
25-
MAX_LENGTH = 512
25+
MAX_LENGTH = 1024
2626
APPROACHES = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"]
27-
MODEL_NAME = "codelion/optillm-bert-uncased"
27+
BASE_MODEL = "answerdotai/ModernBERT-large"
28+
OPTILLM_MODEL_NAME = "codelion/optillm-modernbert-large"
2829

2930
class OptILMClassifier(nn.Module):
3031
def __init__(self, base_model, num_labels):
@@ -49,16 +50,16 @@ def forward(self, input_ids, attention_mask, effort):
4950
def load_optillm_model():
5051
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
5152
# Load the base model
52-
base_model = AutoModel.from_pretrained("google-bert/bert-large-uncased")
53+
base_model = AutoModel.from_pretrained(BASE_MODEL)
5354
# Create the OptILMClassifier
5455
model = OptILMClassifier(base_model, num_labels=len(APPROACHES))
5556
model.to(device)
5657
# Download the safetensors file
57-
safetensors_path = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
58+
safetensors_path = hf_hub_download(repo_id=OPTILLM_MODEL_NAME, filename="model.safetensors")
5859
# Load the state dict from the safetensors file
5960
load_model(model, safetensors_path)
6061

61-
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
62+
tokenizer = AutoTokenizer.from_pretrained(OPTILLM_MODEL_NAME)
6263
return model, tokenizer, device
6364

6465
def preprocess_input(tokenizer, system_prompt, initial_query):

scripts/eval_aime_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
logger = logging.getLogger(__name__)
1616

1717
# Initialize OpenAI client
18-
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url="http://localhost:8888/v1")
18+
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url="http://localhost:8000/v1")
1919

2020
SYSTEM_PROMPT = '''You are solving AIME (American Invitational Mathematics Examination) problems.
2121

scripts/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
datasets
22
accelerate
3-
huggingface_hub
3+
huggingface_hub
4+
git+https://github.com/huggingface/transformers.git

scripts/train_optillm_classifier.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
# Constants
1717
APPROACHES = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"]
18-
MAX_LENGTH = 512
18+
MAX_LENGTH = 1024
1919

2020
# Device selection
2121
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
@@ -233,6 +233,18 @@ def inference(model, tokenizer, prompt, effort_levels):
233233
return results
234234

235235
def main(args):
236+
237+
if args.push_to_hub:
238+
base_model = AutoModel.from_pretrained(args.model_name)
239+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
240+
# best_model = OptILMClassifier(base_model, num_labels=len(APPROACHES))
241+
# best_model.to(device)
242+
# load_model(best_model, "best_model.safetensors")
243+
# we just push the base model and then upload the safetensors file manually as OptILMClassifier class doesn't have a push_to_hub method.
244+
base_model.push_to_hub(args.hub_model_id)
245+
tokenizer.push_to_hub(args.hub_model_id)
246+
return
247+
236248
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
237249
dataset = load_and_preprocess_data(tokenizer)
238250

@@ -273,15 +285,6 @@ def main(args):
273285

274286
print(f"\nBest performing model was from fold {best_fold} with validation accuracy {best_val_accuracy:.4f}")
275287

276-
if args.push_to_hub:
277-
base_model = AutoModel.from_pretrained(args.model_name)
278-
# best_model = OptILMClassifier(base_model, num_labels=len(APPROACHES))
279-
# best_model.to(device)
280-
# load_model(best_model, "best_model.safetensors")
281-
# we just push the base model and then upload the safetensors file manually as OptILMClassifier class doesn't have a push_to_hub method.
282-
base_model.push_to_hub(args.hub_model_id)
283-
tokenizer.push_to_hub(args.hub_model_id)
284-
285288
# Load the best model for inference
286289
base_model = AutoModel.from_pretrained(args.model_name)
287290
best_model = OptILMClassifier(base_model, num_labels=len(APPROACHES))

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="optillm",
5-
version="0.0.22",
5+
version="0.0.23",
66
packages=find_packages(),
77
py_modules=['optillm'],
88
package_data={

0 commit comments

Comments
 (0)