4 changes: 2 additions & 2 deletions README.md
@@ -287,7 +287,7 @@ and then use the same in your OpenAI client. You can pass any HuggingFace model
 with your HuggingFace key. We also support adding any number of LoRAs on top of the model by using the `+` separator.
 
 E.g. The following code loads the base model `meta-llama/Llama-3.2-1B-Instruct` and then adds two LoRAs on top - `patched-codes/Llama-3.2-1B-FixVulns` and `patched-codes/Llama-3.2-1B-FastApply`.
-You can specify which LoRA to use using the `active_adapter` param in `extra_args` field of OpenAI SDK client. By default we will load the last specified adapter.
+You can specify which LoRA to use using the `active_adapter` param in `extra_body` field of OpenAI SDK client. By default we will load the last specified adapter.
 
 ```python
 OPENAI_BASE_URL = "http://localhost:8000/v1"
@@ -748,4 +748,4 @@ If you use this library in your research, please cite:
 
 <p align="center">
 ⭐ <a href="https://github.com/codelion/optillm">Star us on GitHub</a> if you find OptiLLM useful!
-</p>
+</p>
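
The full client example from the README is collapsed in this diff, so here is a minimal sketch of the usage described above. The adapter names come from the paragraph itself, the API key is a placeholder (the README says to pass your HuggingFace key), and `extra_body` is the OpenAI Python SDK's standard field for extra provider-specific parameters, which is the field name this PR corrects:

```python
from openai import OpenAI

OPENAI_BASE_URL = "http://localhost:8000/v1"

# Placeholder; per the README, pass your HuggingFace key here.
client = OpenAI(api_key="your-huggingface-key", base_url=OPENAI_BASE_URL)

response = client.chat.completions.create(
    # Base model plus two LoRAs, joined with the `+` separator.
    model="meta-llama/Llama-3.2-1B-Instruct+patched-codes/Llama-3.2-1B-FixVulns+patched-codes/Llama-3.2-1B-FastApply",
    messages=[{"role": "user", "content": "Hello"}],
    # Select which LoRA handles this request via extra_body.
    extra_body={"active_adapter": "patched-codes/Llama-3.2-1B-FastApply"},
)
print(response.choices[0].message.content)
```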
150 changes: 116 additions & 34 deletions optillm.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion optillm/__init__.py
@@ -2,7 +2,7 @@
 import os
 
 # Version information
-__version__ = "0.1.28"
+__version__ = "0.2.0"
 
 # Get the path to the root optillm.py
 spec = util.spec_from_file_location(
65 changes: 44 additions & 21 deletions optillm/bon.py
@@ -1,8 +1,10 @@
 import logging
+import optillm
+from optillm import conversation_logger
 
 logger = logging.getLogger(__name__)
 
-def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3) -> str:
+def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3, request_id: str = None) -> str:
     bon_completion_tokens = 0
 
     messages = [{"role": "system", "content": system_prompt},
@@ -12,13 +14,20 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
 
     try:
         # Try to generate n completions in a single API call using n parameter
-        response = client.chat.completions.create(
-            model=model,
-            messages=messages,
-            max_tokens=4096,
-            n=n,
-            temperature=1
-        )
+        provider_request = {
+            "model": model,
+            "messages": messages,
+            "max_tokens": 4096,
+            "n": n,
+            "temperature": 1
+        }
+        response = client.chat.completions.create(**provider_request)
+
+        # Log provider call
+        if request_id:
+            response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+            conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
         completions = [choice.message.content for choice in response.choices]
         logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")
         bon_completion_tokens += response.usage.completion_tokens
@@ -30,12 +39,19 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
         # Fallback: Generate completions one by one in a loop
         for i in range(n):
             try:
-                response = client.chat.completions.create(
-                    model=model,
-                    messages=messages,
-                    max_tokens=4096,
-                    temperature=1
-                )
+                provider_request = {
+                    "model": model,
+                    "messages": messages,
+                    "max_tokens": 4096,
+                    "temperature": 1
+                }
+                response = client.chat.completions.create(**provider_request)
+
+                # Log provider call
+                if request_id:
+                    response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+                    conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
                 completions.append(response.choices[0].message.content)
                 bon_completion_tokens += response.usage.completion_tokens
                 logger.debug(f"Generated completion {i+1}/{n}")
@@ -59,13 +75,20 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
         rating_messages.append({"role": "assistant", "content": completion})
         rating_messages.append({"role": "user", "content": "Rate the above response:"})
 
-        rating_response = client.chat.completions.create(
-            model=model,
-            messages=rating_messages,
-            max_tokens=256,
-            n=1,
-            temperature=0.1
-        )
+        provider_request = {
+            "model": model,
+            "messages": rating_messages,
+            "max_tokens": 256,
+            "n": 1,
+            "temperature": 0.1
+        }
+        rating_response = client.chat.completions.create(**provider_request)
+
+        # Log provider call
+        if request_id:
+            response_dict = rating_response.model_dump() if hasattr(rating_response, 'model_dump') else rating_response
+            conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
         bon_completion_tokens += rating_response.usage.completion_tokens
         try:
             rating = float(rating_response.choices[0].message.content.strip())
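
For orientation, a hypothetical standalone call against the updated signature. Inside optillm the dispatch in optillm.py (collapsed above) supplies the client, model, and request_id, so the client setup and values below are assumptions, not the project's own wiring:

```python
from openai import OpenAI

from optillm.bon import best_of_n_sampling

# Any OpenAI-compatible client works; this assumes OPENAI_API_KEY is set.
client = OpenAI()

answer = best_of_n_sampling(
    system_prompt="You are a helpful assistant.",
    initial_query="Explain best-of-n sampling in one paragraph.",
    client=client,
    model="gpt-4o-mini",
    n=3,
    request_id="req-123",  # any non-None id enables conversation_logger.log_provider_call
)
print(answer)
```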