import logging
import os
import secrets
+ import time
from flask import Flask, request, jsonify
from cerebras.cloud.sdk import Cerebras
from openai import AzureOpenAI, OpenAI
from optillm.leap import leap
from optillm.reread import re2_approach
from optillm.cepo.cepo import cepo, CepoConfig, init_cepo_config
+ from optillm.batching import RequestBatcher, BatchingError

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Initialize Flask app
app = Flask(__name__)

+ # Global request batcher (initialized in main() if batch mode enabled)
+ request_batcher = None
+
def get_config():
    API_KEY = None
    if os.environ.get("OPTILLM_API_KEY"):
@@ -683,6 +688,31 @@ def proxy():
        client = default_client

    try:
+         # Route to batch processing if batch mode is enabled
+         if request_batcher is not None:
+             try:
+                 # Create request data for batching
+                 batch_request_data = {
+                     'system_prompt': system_prompt,
+                     'initial_query': initial_query,
+                     'client': client,
+                     'model': model,
+                     'request_config': request_config,
+                     'approaches': approaches,
+                     'operation': operation,
+                     'n': n,
+                     'stream': stream,
+                     'optillm_approach': optillm_approach
+                 }
+
+                 logger.debug("Routing request to batch processor")
+                 result = request_batcher.add_request(batch_request_data)
+                 return jsonify(result), 200
+
+             except BatchingError as e:
+                 logger.error(f"Batch processing failed: {e}")
+                 return jsonify({"error": str(e)}), 500
+
        # Check if any of the approaches is 'none'
        contains_none = any(approach == 'none' for approach in approaches)

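Note (illustration, not part of the diff): once the server runs with --batch-mode, concurrent client calls arriving within the batch window are coalesced by the routing code above. A minimal client sketch, assuming optillm's usual OpenAI-compatible /v1 endpoint on the default port 8000 and a placeholder model name:

from concurrent.futures import ThreadPoolExecutor
from openai import OpenAI

# Point the standard OpenAI client at the local optillm proxy (port 8000 is an assumption)
client = OpenAI(base_url="http://localhost:8000/v1", api_key="optillm")

def ask(question):
    resp = client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder; use whatever model your proxy is configured for
        messages=[{"role": "user", "content": question}],
    )
    return resp.choices[0].message.content

# Requests fired together should land in the same batch window (--batch-wait-ms)
with ThreadPoolExecutor(max_workers=4) as pool:
    print(list(pool.map(ask, ["2+2?", "3+3?", "5+5?", "7+7?"])))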
@@ -849,6 +879,18 @@ def parse_args():
    # Use the function to get the default path
    default_config_path = get_config_path()

+     # Batch mode arguments
+     batch_mode_default = os.environ.get("OPTILLM_BATCH_MODE", "false").lower() == "true"
+     batch_size_default = int(os.environ.get("OPTILLM_BATCH_SIZE", 4))
+     batch_wait_ms_default = int(os.environ.get("OPTILLM_BATCH_WAIT_MS", 50))
+
+     parser.add_argument("--batch-mode", action="store_true", default=batch_mode_default,
+                         help="Enable automatic request batching (fail-fast, no fallback)")
+     parser.add_argument("--batch-size", type=int, default=batch_size_default,
+                         help="Maximum batch size for request batching")
+     parser.add_argument("--batch-wait-ms", dest="batch_wait_ms", type=int, default=batch_wait_ms_default,
+                         help="Maximum wait time in milliseconds for batch formation")
+
    # Special handling of all the CePO Configurations
    for field in fields(CepoConfig):
        parser.add_argument(f"--cepo_{field.name}",
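Note (illustration, not part of the diff): the three environment variables above only seed the argparse defaults, and OPTILLM_BATCH_MODE is compared case-insensitively. A standalone snippet showing the same parsing logic:

import os

# Hypothetical check: any casing of "true" enables batch mode by default
os.environ["OPTILLM_BATCH_MODE"] = "TRUE"
batch_mode_default = os.environ.get("OPTILLM_BATCH_MODE", "false").lower() == "true"
print(batch_mode_default)  # True

# --batch-size and --batch-wait-ms given on the command line still override the env-derived defaults.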
@@ -877,6 +919,7 @@ def parse_args():
def main():
    global server_config
    global cepo_config
+     global request_batcher
    # Call this function at the start of main()
    args = parse_args()
    # Update server_config with all argument values
@@ -885,6 +928,147 @@ def main():
    load_plugins()

    port = server_config['port']
+
+     # Initialize request batcher if batch mode is enabled
+     if server_config.get('batch_mode', False):
+         logger.info(f"Batch mode enabled: size={server_config['batch_size']}, "
+                     f"wait={server_config['batch_wait_ms']}ms")
+         request_batcher = RequestBatcher(
+             max_batch_size=server_config['batch_size'],
+             max_wait_ms=server_config['batch_wait_ms'],
+             enable_logging=True
+         )
+
+         # Set up the batch processor function
+         def process_batch_requests(batch_requests):
+             """
+             Process a batch of requests using true batching when possible
+
+             Args:
+                 batch_requests: List of request data dictionaries
+
+             Returns:
+                 List of response dictionaries
+             """
+             import time
+             from optillm.batching import BatchingError
+
+             if not batch_requests:
+                 return []
+
+             logger.info(f"Processing batch of {len(batch_requests)} requests")
+
+             # Check if we can use true batching (all requests compatible and using 'none' approach)
+             can_use_true_batching = True
+             first_req = batch_requests[0]
+
+             # Check compatibility across all requests
+             for req_data in batch_requests:
+                 if (req_data['stream'] or
+                     req_data['approaches'] != first_req['approaches'] or
+                     req_data['operation'] != first_req['operation'] or
+                     req_data['model'] != first_req['model']):
+                     can_use_true_batching = False
+                     break
+
+             # For now, implement sequential processing but with proper infrastructure
+             # TODO: Implement true PyTorch/MLX batching in next phase
+             responses = []
+
+             for i, req_data in enumerate(batch_requests):
+                 try:
+                     logger.debug(f"Processing batch request {i+1}/{len(batch_requests)}")
+
+                     # Extract request parameters
+                     system_prompt = req_data['system_prompt']
+                     initial_query = req_data['initial_query']
+                     client = req_data['client']
+                     model = req_data['model']
+                     request_config = req_data['request_config']
+                     approaches = req_data['approaches']
+                     operation = req_data['operation']
+                     n = req_data['n']
+                     stream = req_data['stream']
+
+                     # Validate request
+                     if stream:
+                         raise BatchingError("Streaming requests cannot be batched")
+
+                     # Check if any of the approaches is 'none'
+                     contains_none = any(approach == 'none' for approach in approaches)
+
+                     if operation == 'SINGLE' and approaches[0] == 'none':
+                         # Pass through the request including the n parameter
+                         result, completion_tokens = execute_single_approach(
+                             approaches[0], system_prompt, initial_query, client, model, request_config)
+                     elif operation == 'AND' or operation == 'OR':
+                         if contains_none:
+                             raise ValueError("'none' approach cannot be combined with other approaches")
+                         # Handle non-none approaches with n attempts
+                         result, completion_tokens = execute_n_times(
+                             n, approaches, operation, system_prompt, initial_query, client, model, request_config)
+                     else:
+                         # Handle non-none approaches with n attempts
+                         result, completion_tokens = execute_n_times(
+                             n, approaches, operation, system_prompt, initial_query, client, model, request_config)
+
+                     # Convert tagged conversation to messages format if needed
+                     if isinstance(result, list):
+                         processed_response = tagged_conversation_to_messages(result)
+                         if processed_response != result:  # Only process if format changed
+                             result = [msg[-1]['content'] if isinstance(msg, list) and msg else msg
+                                       for msg in processed_response]
+                     else:
+                         messages = tagged_conversation_to_messages(result)
+                         if isinstance(messages, list) and messages:  # Only process if format changed
+                             result = messages[-1]['content']
+
+                     # Generate the response in OpenAI format
+                     if isinstance(result, list):
+                         choices = []
+                         for j, res in enumerate(result):
+                             choices.append({
+                                 "index": j,
+                                 "message": {
+                                     "role": "assistant",
+                                     "content": res
+                                 },
+                                 "finish_reason": "stop"
+                             })
+                     else:
+                         choices = [{
+                             "index": 0,
+                             "message": {
+                                 "role": "assistant",
+                                 "content": result
+                             },
+                             "finish_reason": "stop"
+                         }]
+
+                     response_dict = {
+                         "id": f"chatcmpl-{int(time.time()*1000)}-{i}",
+                         "object": "chat.completion",
+                         "created": int(time.time()),
+                         "model": model,
+                         "choices": choices,
+                         "usage": {
+                             "prompt_tokens": 0,  # Will be calculated properly later
+                             "completion_tokens": completion_tokens if isinstance(completion_tokens, int) else 0,
+                             "total_tokens": completion_tokens if isinstance(completion_tokens, int) else 0
+                         }
+                     }
+
+                     responses.append(response_dict)
+
+                 except Exception as e:
+                     logger.error(f"Error processing batch request {i+1}: {e}")
+                     raise BatchingError(f"Failed to process request {i+1}: {str(e)}")
+
+             logger.info(f"Completed batch processing of {len(responses)} requests")
+             return responses
+
+         # Set the processor function on the batcher
+         request_batcher.set_processor(process_batch_requests)

    # Set logging level from user request
    logging_level = server_config['log']
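Note (illustration, not part of the diff): the real RequestBatcher lives in optillm/batching.py and is not shown in this commit. The sketch below only makes the interface used above concrete — RequestBatcher(max_batch_size, max_wait_ms, enable_logging), set_processor(fn), add_request(data) returning one response per request, and BatchingError on failure; the actual implementation may differ.

import threading
import time


class BatchingError(Exception):
    """Raised when a batch cannot be formed or processed."""


class RequestBatcher:
    """Collects concurrent requests into batches and hands them to a processor."""

    def __init__(self, max_batch_size=4, max_wait_ms=50, enable_logging=False):
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.enable_logging = enable_logging
        self._processor = None
        self._lock = threading.Lock()
        self._pending = []  # (request_data, done_event, result_slot) tuples

    def set_processor(self, fn):
        # fn receives a list of request dicts and must return one response dict per request
        self._processor = fn

    def add_request(self, request_data):
        if self._processor is None:
            raise BatchingError("No batch processor configured")
        done = threading.Event()
        slot = {}
        with self._lock:
            self._pending.append((request_data, done, slot))
            is_leader = len(self._pending) == 1
        if is_leader:
            # First request in a window waits for the batch to fill or the deadline to pass
            deadline = time.time() + self.max_wait_ms / 1000.0
            while time.time() < deadline:
                with self._lock:
                    if len(self._pending) >= self.max_batch_size:
                        break
                time.sleep(0.001)
            self._flush()
        done.wait()
        if "error" in slot:
            raise BatchingError(slot["error"])
        return slot["response"]

    def _flush(self):
        with self._lock:
            batch, self._pending = self._pending, []
        try:
            responses = self._processor([req for req, _, _ in batch])
            for (_, done, slot), resp in zip(batch, responses):
                slot["response"] = resp
                done.set()
        except Exception as exc:
            # Fail-fast: every request in the batch sees the same error
            for _, done, slot in batch:
                slot["error"] = str(exc)
                done.set()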