 from functools import lru_cache
 import time

+from optillm.cot_decoding import cot_decode
+from optillm.entropy_decoding import entropy_decode
+
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -1221,6 +1224,17 @@ def create(
         logprobs: Optional[bool] = None,
         top_logprobs: Optional[int] = None,
         active_adapter: Optional[Dict[str, Any]] = None,
+        decoding: Optional[str] = None,
+        # CoT specific params
+        k: int = 10,
+        num_beams: int = 1,
+        length_penalty: float = 1.0,
+        no_repeat_ngram_size: int = 0,
+        early_stopping: bool = False,
+        aggregate_paths: bool = False,
+        # Entropy specific params
+        top_k: int = 27,
+        min_p: float = 0.03,
         **kwargs
     ) -> ChatCompletion:
         """Create a chat completion with OpenAI-compatible parameters"""
@@ -1229,71 +1243,145 @@ def create(

         pipeline = self.client.get_pipeline(model)

-        # Set active adapter if specified in extra_body
+        # Set active adapter if specified
         if active_adapter is not None:
             logger.info(f"Setting active adapter to: {active_adapter}")
             pipeline.lora_manager.set_active_adapter(pipeline.current_model, active_adapter)
-
-        # Apply chat template to messages
-        prompt = pipeline.tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True
-        )
-
-        # Set generation parameters
-        generation_params = {
-            "temperature": temperature,
-            "top_p": top_p,
-            "num_return_sequences": n,
-            "max_new_tokens": max_tokens if max_tokens is not None else 4096,
-            "presence_penalty": presence_penalty,
-            "frequency_penalty": frequency_penalty,
-            "stop_sequences": [stop] if isinstance(stop, str) else stop,
-            "seed": seed,
-            "logit_bias": logit_bias,
-            "logprobs": logprobs,
-            "top_logprobs": top_logprobs
-        }

-        # Generate responses - now handles logprobs
-        responses, token_counts, logprobs_results = pipeline.generate(
-            prompt,
-            generation_params=generation_params
-        )
-
-        # Calculate prompt tokens
-        prompt_tokens = len(pipeline.tokenizer.encode(prompt))
-        completion_tokens = sum(token_counts)
-
-        # Create OpenAI-compatible response format
-        response_dict = {
-            "id": f"chatcmpl-{int(time.time()*1000)}",
-            "object": "chat.completion",
-            "created": int(time.time()),
-            "model": model,
-            "choices": [
-                {
-                    "index": idx,
-                    "message": {
-                        "role": "assistant",
-                        "content": response,
-                        **({"logprobs": logprob_result} if logprob_result else {})
-                    },
-                    "finish_reason": "stop"
+        responses = []
+        logprobs_results = []
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        try:
+            # Handle specialized decoding approaches
+            if decoding:
+                logger.info(f"Using specialized decoding approach: {decoding}")
+
+                if decoding == "cot_decoding":
+                    # Use directly available parameters for CoT
+                    cot_params = {
+                        "k": k,
+                        "num_beams": num_beams,
+                        "max_new_tokens": max_tokens if max_tokens is not None else 4096,
+                        "temperature": temperature,
+                        "top_p": top_p,
+                        "repetition_penalty": 1.0 + frequency_penalty,
+                        "length_penalty": length_penalty,
+                        "no_repeat_ngram_size": no_repeat_ngram_size,
+                        "early_stopping": early_stopping,
+                        "aggregate_paths": aggregate_paths,
+                    }
+
+                    result, confidence = cot_decode(
+                        pipeline.current_model,
+                        pipeline.tokenizer,
+                        messages,
+                        **cot_params
+                    )
+                    responses = [result]
+                    logprobs_results = [{"confidence_score": confidence} if confidence is not None else None]
+                    completion_tokens = len(pipeline.tokenizer.encode(result))
+
+                elif decoding == "entropy_decoding":
+                    # Configure generator for entropy decoding
+                    generator = None
+                    if seed is not None:
+                        generator = torch.Generator(device=pipeline.current_model.device)
+                        generator.manual_seed(seed)
+
+                    # Use directly available parameters for entropy decoding
+
+                    entropy_params = {
+                        "max_new_tokens": max_tokens if max_tokens is not None else 4096,
+                        "temperature": 0.666,
+                        "top_p": 0.90,
+                        "top_k": top_k,
+                        "min_p": min_p,
+                        "generator": generator
+                    }
+
+                    result = entropy_decode(
+                        pipeline.current_model,
+                        pipeline.tokenizer,
+                        messages,
+                        **entropy_params
+                    )
+                    responses = [result]
+                    logprobs_results = [None]
+                    completion_tokens = len(pipeline.tokenizer.encode(result))
+
+                else:
+                    raise ValueError(f"Unknown specialized decoding approach: {decoding}")
+
+                # Calculate prompt tokens for specialized approaches
+                prompt_text = pipeline.tokenizer.apply_chat_template(messages, tokenize=False)
+                prompt_tokens = len(pipeline.tokenizer.encode(prompt_text))
+
+            else:
+                # Standard generation
+                prompt = pipeline.tokenizer.apply_chat_template(
+                    messages,
+                    tokenize=False,
+                    add_generation_prompt=True
+                )
+
+                # Set generation parameters
+                generation_params = {
+                    "temperature": temperature,
+                    "top_p": top_p,
+                    "num_return_sequences": n,
+                    "max_new_tokens": max_tokens if max_tokens is not None else 4096,
+                    "presence_penalty": presence_penalty,
+                    "frequency_penalty": frequency_penalty,
+                    "stop_sequences": [stop] if isinstance(stop, str) else stop,
+                    "seed": seed,
+                    "logit_bias": logit_bias,
+                    "logprobs": logprobs,
+                    "top_logprobs": top_logprobs
+                }
+
+                # Generate responses
+                responses, token_counts, logprobs_results = pipeline.generate(
+                    prompt,
+                    generation_params=generation_params
+                )
+
+                prompt_tokens = len(pipeline.tokenizer.encode(prompt))
+                completion_tokens = sum(token_counts)
+
+            # Create OpenAI-compatible response format
+            response_dict = {
+                "id": f"chatcmpl-{int(time.time()*1000)}",
+                "object": "chat.completion",
+                "created": int(time.time()),
+                "model": model,
+                "choices": [
+                    {
+                        "index": idx,
+                        "message": {
+                            "role": "assistant",
+                            "content": response,
+                            **({"logprobs": logprob_result} if logprob_result else {})
+                        },
+                        "finish_reason": "stop"
+                    }
+                    for idx, (response, logprob_result) in enumerate(zip(responses, logprobs_results))
+                ],
+                "usage": {
+                    "prompt_tokens": prompt_tokens,
+                    "completion_tokens": completion_tokens,
+                    "total_tokens": completion_tokens + prompt_tokens
                 }
-                for idx, (response, logprob_result) in enumerate(zip(responses, logprobs_results))
-            ],
-            "usage": {
-                "prompt_tokens": prompt_tokens,
-                "completion_tokens": completion_tokens,
-                "total_tokens": completion_tokens + prompt_tokens
             }
-        }
-
-        self.client.clean_unused_pipelines()
-        logger.debug(f"Response: {response_dict}")
-        return ChatCompletion(response_dict)
+
+            self.client.clean_unused_pipelines()
+            logger.debug(f"Response: {response_dict}")
+            return ChatCompletion(response_dict)
+
+        except Exception as e:
+            logger.error(f"Error in chat completion: {str(e)}")
+            raise

 class Models:
     """OpenAI-compatible models interface"""