Commit 865da13

1 parent aad79ca

15 files changed: 1,569 additions & 382 deletions

.github/workflows/test.yml

Lines changed: 25 additions & 0 deletions
@@ -37,6 +37,9 @@ jobs:
 
       - name: Run unit tests
         run: |
+          # Set up local inference environment
+          export OPTILLM_API_KEY=optillm
+
           # Run quick CI tests
           python tests/test_ci_quick.py
 
@@ -45,6 +48,28 @@ jobs:
 
           # Run approach tests
           python tests/test_approaches.py
+
+          # Run reasoning token tests (unit tests only - no MLX)
+          python tests/test_reasoning_simple.py
+          python tests/test_reasoning_tokens.py
+          python tests/test_reasoning_integration.py
+
+          # Run batching tests (MLX tests auto-skipped on Ubuntu)
+          python tests/test_batching.py
+
+          # Run JSON plugin tests
+          python tests/test_json_plugin.py
+
+          # Run n-parameter tests
+          python tests/test_n_parameter.py
+
+          # Run API compatibility tests with pytest if available
+          python -m pytest tests/test_api_compatibility.py -v --tb=short || echo "API compatibility tests require pytest"
+
+          # Run main test framework with basic tests
+          python tests/test.py --approaches none --single-test "Simple Math Problem"
+        env:
+          OPTILLM_API_KEY: optillm
 
   integration-test:
     runs-on: ubuntu-latest
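
The "MLX tests auto-skipped on Ubuntu" comment above implies a platform guard inside the batching tests. A minimal sketch of one way such a guard can look, assuming a unittest-style suite (the class and test names are illustrative, not the actual contents of tests/test_batching.py):

import platform
import unittest

# MLX only runs on Apple Silicon, so gate on the OS reported by the runner.
MLX_AVAILABLE = platform.system() == "Darwin"

@unittest.skipUnless(MLX_AVAILABLE, "MLX not available on this platform")
class TestMLXBatching(unittest.TestCase):  # hypothetical test class
    def test_batched_generation(self):
        self.assertTrue(MLX_AVAILABLE)

if __name__ == "__main__":
    unittest.main()  # on an Ubuntu runner this reports a skip and exits 0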

optillm.py

Lines changed: 184 additions & 0 deletions
@@ -2,6 +2,7 @@
 import logging
 import os
 import secrets
+import time
 from flask import Flask, request, jsonify
 from cerebras.cloud.sdk import Cerebras
 from openai import AzureOpenAI, OpenAI
@@ -30,6 +31,7 @@
 from optillm.leap import leap
 from optillm.reread import re2_approach
 from optillm.cepo.cepo import cepo, CepoConfig, init_cepo_config
+from optillm.batching import RequestBatcher, BatchingError
 
 # Setup logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -46,6 +48,9 @@
 # Initialize Flask app
 app = Flask(__name__)
 
+# Global request batcher (initialized in main() if batch mode enabled)
+request_batcher = None
+
 def get_config():
     API_KEY = None
     if os.environ.get("OPTILLM_API_KEY"):
@@ -683,6 +688,31 @@ def proxy():
     client = default_client
 
     try:
+        # Route to batch processing if batch mode is enabled
+        if request_batcher is not None:
+            try:
+                # Create request data for batching
+                batch_request_data = {
+                    'system_prompt': system_prompt,
+                    'initial_query': initial_query,
+                    'client': client,
+                    'model': model,
+                    'request_config': request_config,
+                    'approaches': approaches,
+                    'operation': operation,
+                    'n': n,
+                    'stream': stream,
+                    'optillm_approach': optillm_approach
+                }
+
+                logger.debug("Routing request to batch processor")
+                result = request_batcher.add_request(batch_request_data)
+                return jsonify(result), 200
+
+            except BatchingError as e:
+                logger.error(f"Batch processing failed: {e}")
+                return jsonify({"error": str(e)}), 500
+
         # Check if any of the approaches is 'none'
         contains_none = any(approach == 'none' for approach in approaches)
 
@@ -849,6 +879,18 @@ def parse_args():
     # Use the function to get the default path
     default_config_path = get_config_path()
 
+    # Batch mode arguments
+    batch_mode_default = os.environ.get("OPTILLM_BATCH_MODE", "false").lower() == "true"
+    batch_size_default = int(os.environ.get("OPTILLM_BATCH_SIZE", 4))
+    batch_wait_ms_default = int(os.environ.get("OPTILLM_BATCH_WAIT_MS", 50))
+
+    parser.add_argument("--batch-mode", action="store_true", default=batch_mode_default,
+                        help="Enable automatic request batching (fail-fast, no fallback)")
+    parser.add_argument("--batch-size", type=int, default=batch_size_default,
+                        help="Maximum batch size for request batching")
+    parser.add_argument("--batch-wait-ms", dest="batch_wait_ms", type=int, default=batch_wait_ms_default,
+                        help="Maximum wait time in milliseconds for batch formation")
+
     # Special handling of all the CePO Configurations
     for field in fields(CepoConfig):
         parser.add_argument(f"--cepo_{field.name}",
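
The three defaults above follow an env-to-default pattern: environment variables seed the argparse defaults, so the OPTILLM_BATCH_* variables and the CLI flags are interchangeable, with explicit flags winning. A self-contained sketch of the same pattern:

import argparse
import os

os.environ["OPTILLM_BATCH_MODE"] = "true"  # e.g. exported by a CI job

parser = argparse.ArgumentParser()
parser.add_argument("--batch-mode", action="store_true",
                    default=os.environ.get("OPTILLM_BATCH_MODE", "false").lower() == "true")
parser.add_argument("--batch-size", type=int,
                    default=int(os.environ.get("OPTILLM_BATCH_SIZE", 4)))

args = parser.parse_args([])                      # no flags: env defaults win
assert args.batch_mode is True and args.batch_size == 4
args = parser.parse_args(["--batch-size", "8"])   # explicit flag overrides env
assert args.batch_size == 8

One consequence of pairing action="store_true" with an env-driven default: once OPTILLM_BATCH_MODE=true is set, there is no command-line flag to switch batching back off.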
@@ -877,6 +919,7 @@ def parse_args():
 def main():
     global server_config
     global cepo_config
+    global request_batcher
     # Call this function at the start of main()
     args = parse_args()
     # Update server_config with all argument values
@@ -885,6 +928,147 @@ def main():
     load_plugins()
 
     port = server_config['port']
+
+    # Initialize request batcher if batch mode is enabled
+    if server_config.get('batch_mode', False):
+        logger.info(f"Batch mode enabled: size={server_config['batch_size']}, "
+                    f"wait={server_config['batch_wait_ms']}ms")
+        request_batcher = RequestBatcher(
+            max_batch_size=server_config['batch_size'],
+            max_wait_ms=server_config['batch_wait_ms'],
+            enable_logging=True
+        )
+
+        # Set up the batch processor function
+        def process_batch_requests(batch_requests):
+            """
+            Process a batch of requests using true batching when possible
+
+            Args:
+                batch_requests: List of request data dictionaries
+
+            Returns:
+                List of response dictionaries
+            """
+            import time
+            from optillm.batching import BatchingError
+
+            if not batch_requests:
+                return []
+
+            logger.info(f"Processing batch of {len(batch_requests)} requests")
+
+            # Check if we can use true batching (all requests compatible and using 'none' approach)
+            can_use_true_batching = True
+            first_req = batch_requests[0]
+
+            # Check compatibility across all requests
+            for req_data in batch_requests:
+                if (req_data['stream'] or
+                        req_data['approaches'] != first_req['approaches'] or
+                        req_data['operation'] != first_req['operation'] or
+                        req_data['model'] != first_req['model']):
+                    can_use_true_batching = False
+                    break
+
+            # For now, implement sequential processing but with proper infrastructure
+            # TODO: Implement true PyTorch/MLX batching in next phase
+            responses = []
+
+            for i, req_data in enumerate(batch_requests):
+                try:
+                    logger.debug(f"Processing batch request {i+1}/{len(batch_requests)}")
+
+                    # Extract request parameters
+                    system_prompt = req_data['system_prompt']
+                    initial_query = req_data['initial_query']
+                    client = req_data['client']
+                    model = req_data['model']
+                    request_config = req_data['request_config']
+                    approaches = req_data['approaches']
+                    operation = req_data['operation']
+                    n = req_data['n']
+                    stream = req_data['stream']
+
+                    # Validate request
+                    if stream:
+                        raise BatchingError("Streaming requests cannot be batched")
+
+                    # Check if any of the approaches is 'none'
+                    contains_none = any(approach == 'none' for approach in approaches)
+
+                    if operation == 'SINGLE' and approaches[0] == 'none':
+                        # Pass through the request including the n parameter
+                        result, completion_tokens = execute_single_approach(
+                            approaches[0], system_prompt, initial_query, client, model, request_config)
+                    elif operation == 'AND' or operation == 'OR':
+                        if contains_none:
+                            raise ValueError("'none' approach cannot be combined with other approaches")
+                        # Handle non-none approaches with n attempts
+                        result, completion_tokens = execute_n_times(
+                            n, approaches, operation, system_prompt, initial_query, client, model, request_config)
+                    else:
+                        # Handle non-none approaches with n attempts
+                        result, completion_tokens = execute_n_times(
+                            n, approaches, operation, system_prompt, initial_query, client, model, request_config)
+
+                    # Convert tagged conversation to messages format if needed
+                    if isinstance(result, list):
+                        processed_response = tagged_conversation_to_messages(result)
+                        if processed_response != result:  # Only process if format changed
+                            result = [msg[-1]['content'] if isinstance(msg, list) and msg else msg
+                                      for msg in processed_response]
+                    else:
+                        messages = tagged_conversation_to_messages(result)
+                        if isinstance(messages, list) and messages:  # Only process if format changed
+                            result = messages[-1]['content']
+
+                    # Generate the response in OpenAI format
+                    if isinstance(result, list):
+                        choices = []
+                        for j, res in enumerate(result):
+                            choices.append({
+                                "index": j,
+                                "message": {
+                                    "role": "assistant",
+                                    "content": res
+                                },
+                                "finish_reason": "stop"
+                            })
+                    else:
+                        choices = [{
+                            "index": 0,
+                            "message": {
+                                "role": "assistant",
+                                "content": result
+                            },
+                            "finish_reason": "stop"
+                        }]
+
+                    response_dict = {
+                        "id": f"chatcmpl-{int(time.time()*1000)}-{i}",
+                        "object": "chat.completion",
+                        "created": int(time.time()),
+                        "model": model,
+                        "choices": choices,
+                        "usage": {
+                            "prompt_tokens": 0,  # Will be calculated properly later
+                            "completion_tokens": completion_tokens if isinstance(completion_tokens, int) else 0,
+                            "total_tokens": completion_tokens if isinstance(completion_tokens, int) else 0
+                        }
+                    }
+
+                    responses.append(response_dict)
+
+                except Exception as e:
+                    logger.error(f"Error processing batch request {i+1}: {e}")
+                    raise BatchingError(f"Failed to process request {i+1}: {str(e)}")
+
+            logger.info(f"Completed batch processing of {len(responses)} requests")
+            return responses
+
+        # Set the processor function on the batcher
+        request_batcher.set_processor(process_batch_requests)
 
     # Set logging level from user request
     logging_level = server_config['log']
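
The handler in proxy() relies on two properties of RequestBatcher: add_request() blocks until the processor callback has produced this request's response, and a batch is flushed when it reaches max_batch_size or when max_wait_ms elapses. A minimal single-file sketch of that assumed contract; this is NOT the actual optillm.batching implementation:

import threading

class SketchBatcher:
    """Assumed semantics only; see optillm/batching.py for the real thing."""

    def __init__(self, max_batch_size=4, max_wait_ms=50):
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self._lock = threading.Lock()
        self._pending = []  # (request, one-slot result holder, done event)
        self._processor = None

    def set_processor(self, fn):
        self._processor = fn  # fn: list of requests -> list of responses

    def add_request(self, request):
        slot, done = [None], threading.Event()
        with self._lock:
            self._pending.append((request, slot, done))
            if len(self._pending) >= self.max_batch_size:
                self._flush_locked()  # size trigger
        if not done.wait(self.max_wait_ms / 1000):  # timeout trigger
            with self._lock:
                if not done.is_set():
                    self._flush_locked()
        done.wait()
        return slot[0]

    def _flush_locked(self):
        # Simplification: the processor runs while the lock is held, so
        # flushes are serialized and new arrivals wait until it returns.
        batch, self._pending = self._pending, []
        results = self._processor([req for req, _, _ in batch]) if batch else []
        for (_, slot, done), result in zip(batch, results):
            slot[0] = result
            done.set()

batcher = SketchBatcher(max_batch_size=2, max_wait_ms=50)
batcher.set_processor(lambda reqs: [r.upper() for r in reqs])
print(batcher.add_request("hello"))  # flushed by timeout after ~50 ms -> HELLO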

optillm/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@
 get_config = module.get_config
 load_plugins = module.load_plugins
 count_reasoning_tokens = module.count_reasoning_tokens
+parse_args = module.parse_args
 
 # Export execution functions
 execute_single_approach = module.execute_single_approach
@@ -50,6 +51,7 @@
     'get_config',
     'load_plugins',
     'count_reasoning_tokens',
+    'parse_args',
     'execute_single_approach',
     'execute_combined_approaches',
     'execute_parallel_approaches',
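
Exporting parse_args lets embedding scripts reuse the server's CLI and environment handling instead of duplicating it. A hedged example, assuming optillm and its dependencies are importable (the exact flag set depends on the installed version):

import sys
import optillm

sys.argv = ["optillm", "--batch-mode", "--batch-size", "8"]  # simulate CLI args
args = optillm.parse_args()
print(args.batch_mode, args.batch_size)  # expected: True 8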
