
Commit 2c66fc2

Add none approach
fix logprobs return
1 parent 60a84ef commit 2c66fc2

File tree

2 files changed: +91 −20 lines changed


optillm.py

Lines changed: 87 additions & 18 deletions
@@ -11,6 +11,7 @@
 import asyncio
 import re
 from concurrent.futures import ThreadPoolExecutor
+from typing import Tuple, Optional, Union, Dict, Any, List
 
 # Import approach modules
 from optillm.mcts import chat_with_mcts
@@ -83,7 +84,7 @@ def get_config():
 
 # Server configuration
 server_config = {
-    'approach': 'bon',
+    'approach': 'none',
     'mcts_simulations': 2,
    'mcts_exploration': 0.2,
    'mcts_depth': 1,
@@ -101,11 +102,52 @@ def get_config():
 }
 
 # List of known approaches
-known_approaches = ["mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar",
-                    "cot_reflection", "plansearch", "leap", "re2"]
+known_approaches = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency",
+                    "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"]
 
 plugin_approaches = {}
 
+def none_approach(
+    client: Any,
+    model: str,
+    original_messages: List[Dict[str, str]],
+    **kwargs
+) -> Dict[str, Any]:
+    """
+    Direct proxy approach that passes all parameters through to the underlying endpoint.
+
+    Args:
+        client: OpenAI client instance
+        model: Model identifier
+        original_messages: Original messages from the request
+        **kwargs: Additional parameters to pass through
+
+    Returns:
+        Dict[str, Any]: Full OpenAI API response
+    """
+    # Strip the 'none-' prefix from the model name if present
+    if model.startswith('none-'):
+        model = model[5:]
+
+    try:
+        # Make the completion call directly with the original messages and parameters
+        response = client.chat.completions.create(
+            model=model,
+            messages=original_messages,
+            **kwargs
+        )
+
+        # Convert the response to a dict if it is not one already
+        if hasattr(response, 'model_dump'):
+            return response.model_dump()
+        return response
+
+    except Exception as e:
+        logger.error(f"Error in none approach: {str(e)}")
+        raise
+
 def load_plugins():
     # Clear existing plugins first but modify the global dict in place
     plugin_approaches.clear()
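
For context, the new pass-through can be exercised end to end once the proxy is running. The sketch below is illustrative rather than part of the commit: it assumes the proxy listens on http://localhost:8000/v1 (adjust base_url and the API key to your setup) and uses gpt-4o-mini as a placeholder model name.

from openai import OpenAI

# Point the standard OpenAI client at the optillm proxy (base_url is an assumption).
client = OpenAI(api_key="sk-your-key", base_url="http://localhost:8000/v1")

# The 'none-' prefix selects none_approach; the prefix is stripped before the
# call is forwarded, so the upstream endpoint sees model="gpt-4o-mini".
response = client.chat.completions.create(
    model="none-gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hello."}],
    temperature=0.7,  # extra parameters are forwarded unchanged via **kwargs
)
print(response.choices[0].message.content)
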
@@ -163,7 +205,7 @@ def load_plugins():
 
 def parse_combined_approach(model: str, known_approaches: list, plugin_approaches: dict):
     if model == 'auto':
-        return 'SINGLE', ['bon'], model
+        return 'SINGLE', ['none'], model
 
     parts = model.split('-')
     approaches = []
@@ -188,7 +230,7 @@ def parse_combined_approach(model: str, known_approaches: list, plugin_approache
         model_parts.append(part)
 
     if not approaches:
-        approaches = ['bon']
+        approaches = ['none']
         operation = 'SINGLE'
 
     actual_model = '-'.join(model_parts)
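
The net effect of the two default changes above can be sketched as expected return values (illustrative only; the model names are placeholders):

# Expected parses after this change:
parse_combined_approach('auto', known_approaches, plugin_approaches)
# -> ('SINGLE', ['none'], 'auto')

parse_combined_approach('moa-gpt-4o-mini', known_approaches, plugin_approaches)
# -> ('SINGLE', ['moa'], 'gpt-4o-mini')

parse_combined_approach('gpt-4o-mini', known_approaches, plugin_approaches)
# -> ('SINGLE', ['none'], 'gpt-4o-mini')  # the fallback is now 'none', not 'bon'
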
@@ -197,8 +239,21 @@ def parse_combined_approach(model: str, known_approaches: list, plugin_approache
 
 def execute_single_approach(approach, system_prompt, initial_query, client, model):
     if approach in known_approaches:
-        # Execute known approaches
-        if approach == 'mcts':
+        if approach == 'none':
+            # Extract kwargs from the request data
+            kwargs = {}
+            if hasattr(request, 'json'):
+                data = request.get_json()
+                messages = data.get('messages', [])
+                # Copy all parameters except 'model', 'messages', and 'optillm_approach'
+                kwargs = {k: v for k, v in data.items()
+                          if k not in ['model', 'messages', 'optillm_approach']}
+            response = none_approach(original_messages=messages, client=client, model=model, **kwargs)
+
+            # For the none approach, return the response with a token count of 0,
+            # since the full token count is already included in the response
+            return response, 0
+        elif approach == 'mcts':
             return chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
                                   server_config['mcts_exploration'], server_config['mcts_depth'])
         elif approach == 'bon':
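
To make the kwargs filtering above concrete, here is a small self-contained illustration with a hypothetical request body (the field values are made up):

# Hypothetical JSON body of an incoming request:
data = {
    "model": "none-gpt-4o-mini",
    "messages": [{"role": "user", "content": "Hi"}],
    "optillm_approach": "none",
    "temperature": 0.2,
    "max_tokens": 64,
}

# The same comprehension as above: everything except the routing fields is forwarded.
kwargs = {k: v for k, v in data.items()
          if k not in ['model', 'messages', 'optillm_approach']}
assert kwargs == {"temperature": 0.2, "max_tokens": 64}
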
@@ -329,7 +384,6 @@ def proxy():
     bearer_token = ""
 
     if auth_header and auth_header.startswith("Bearer "):
-        # Extract the bearer token
         bearer_token = auth_header.split("Bearer ")[1].strip()
         logger.debug(f"Intercepted Bearer Token: {bearer_token}")

@@ -365,22 +419,37 @@ def proxy():
     client = default_client
 
     try:
+        # Check whether any of the approaches is 'none'
+        contains_none = any(approach == 'none' for approach in approaches)
+
+        if operation == 'SINGLE' and approaches[0] == 'none':
+            # For the none approach, return the response directly
+            result, _ = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
+            logger.debug(f'Direct proxy response: {result}')
+            return jsonify(result), 200
+
+        elif operation == 'AND' or operation == 'OR':
+            if contains_none:
+                raise ValueError("'none' approach cannot be combined with other approaches")
+
+        # Handle non-none approaches
         if operation == 'SINGLE':
-            final_response, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
+            response, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
         elif operation == 'AND':
-            final_response, completion_tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model)
+            response, completion_tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model)
         elif operation == 'OR':
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
-            final_response, completion_tokens = loop.run_until_complete(execute_parallel_approaches(approaches, system_prompt, initial_query, client, model))
+            response, completion_tokens = loop.run_until_complete(execute_parallel_approaches(approaches, system_prompt, initial_query, client, model))
         else:
             raise ValueError(f"Unknown operation: {operation}")
+
     except Exception as e:
         logger.error(f"Error processing request: {str(e)}")
         return jsonify({"error": str(e)}), 500
 
     if stream:
-        return Response(generate_streaming_response(final_response, model), content_type='text/event-stream')
+        return Response(generate_streaming_response(response, model), content_type='text/event-stream')
     else:
         response_data = {
             'model': model,
@@ -390,13 +459,13 @@ def proxy():
             }
         }
 
-        if isinstance(final_response, list):
-            for index, response in enumerate(final_response):
+        if isinstance(response, list):
+            for index, resp in enumerate(response):
                 response_data['choices'].append({
                     'index': index,
                     'message': {
                         'role': 'assistant',
-                        'content': response,
+                        'content': resp,
                     },
                     'finish_reason': 'stop'
                 })
@@ -405,13 +474,13 @@ def proxy():
                 'index': 0,
                 'message': {
                     'role': 'assistant',
-                    'content': final_response,
+                    'content': response,
                 },
                 'finish_reason': 'stop'
             })
 
-    logger.debug(f'API response: {response_data}')
-    return jsonify(response_data), 200
+        logger.debug(f'API response: {response_data}')
+        return jsonify(response_data), 200
 
 @app.route('/v1/models', methods=['GET'])
 def proxy_models():

optillm/inference.py

Lines changed: 4 additions & 2 deletions
@@ -28,14 +28,15 @@ class ModelConfig:
     quantization_bits: int = 4
     device_preference: Optional[str] = None
     # Default generation parameters
-    max_new_tokens: int = 512
+    max_new_tokens: int = 4096
     do_sample: bool = True
     top_p: float = 0.9
     top_k: int = 50
     temperature: float = 0.7
     num_return_sequences: int = 1
     repetition_penalty: float = 1.0
     pad_token_id: Optional[int] = None
+    logprobs: bool = False
     # Advanced parameters
     use_memory_efficient_attention: bool = True
     enable_prompt_caching: bool = True
@@ -769,7 +770,7 @@ def generate(
                 })
             else:
                 logprobs_results.append(None)
-
+        logger.debug(f"Logprobs_results : {logprobs_results}")
         return responses, token_counts, logprobs_results
 
     def setup_efficient_attention(self):
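
With logprobs now part of ModelConfig and logprobs_results surfaced by generate, a client should be able to request log probabilities through the proxy. A minimal sketch, reusing the client from the earlier example and assuming the OpenAI-style request schema (whether every field is honored depends on the local inference path):

# Request token log probabilities alongside the completion.
response = client.chat.completions.create(
    model="none-gpt-4o-mini",
    messages=[{"role": "user", "content": "Hi"}],
    logprobs=True,
)
print(response.choices[0].logprobs)
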
@@ -1268,6 +1269,7 @@ def create(
         }
 
         self.client.clean_unused_pipelines()
+        logger.debug(f"Response : {response_dict}")
         return ChatCompletion(response_dict)
 
 class Models:
