
Commit e6d61b1

add support for cot_decoding and entropy_decoding in local inference server
1 parent 24b559d commit e6d61b1
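
The commit routes requests through two specialized decoders, cot_decode and entropy_decode, selected via a new `decoding` parameter on the server's OpenAI-compatible create() method (full diff below). As a rough illustration of how a client might exercise the new path, here is a sketch using the openai Python SDK. It assumes the server passes extra_body fields (decoding, k, aggregate_paths, and so on) through to create(); the base URL, API key, and model name are placeholders, not anything fixed by this commit.

    from openai import OpenAI

    # Sketch only: assumes a local optillm inference server on port 8000 and that
    # extra_body fields are forwarded into the create() signature shown in the diff.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="optillm")

    response = client.chat.completions.create(
        model="meta-llama/Llama-3.2-1B-Instruct",  # placeholder local model
        messages=[{"role": "user", "content": "What is 7 * 8 + 3?"}],
        extra_body={
            "decoding": "cot_decoding",  # or "entropy_decoding"
            "k": 10,                     # number of alternative first tokens to branch on
            "aggregate_paths": True,     # aggregate answers across paths instead of best-path only
        },
    )
    print(response.choices[0].message.content)

When decoding is omitted, the request falls through to the existing pipeline.generate() path unchanged.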

File tree: 1 file changed

optillm/inference.py

Lines changed: 147 additions & 59 deletions
@@ -15,6 +15,9 @@
 from functools import lru_cache
 import time
 
+from optillm.cot_decoding import cot_decode
+from optillm.entropy_decoding import entropy_decode
+
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -1221,6 +1224,17 @@ def create(
         logprobs: Optional[bool] = None,
         top_logprobs: Optional[int] = None,
         active_adapter: Optional[Dict[str, Any]] = None,
+        decoding: Optional[str] = None,
+        # CoT specific params
+        k: int = 10,
+        num_beams: int = 1,
+        length_penalty: float = 1.0,
+        no_repeat_ngram_size: int = 0,
+        early_stopping: bool = False,
+        aggregate_paths: bool = False,
+        # Entropy specific params
+        top_k: int = 27,
+        min_p: float = 0.03,
         **kwargs
     ) -> ChatCompletion:
         """Create a chat completion with OpenAI-compatible parameters"""
@@ -1229,71 +1243,145 @@ def create(
 
         pipeline = self.client.get_pipeline(model)
 
-        # Set active adapter if specified in extra_body
+        # Set active adapter if specified
         if active_adapter is not None:
             logger.info(f"Setting active adapter to: {active_adapter}")
             pipeline.lora_manager.set_active_adapter(pipeline.current_model, active_adapter)
-
-        # Apply chat template to messages
-        prompt = pipeline.tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True
-        )
-
-        # Set generation parameters
-        generation_params = {
-            "temperature": temperature,
-            "top_p": top_p,
-            "num_return_sequences": n,
-            "max_new_tokens": max_tokens if max_tokens is not None else 4096,
-            "presence_penalty": presence_penalty,
-            "frequency_penalty": frequency_penalty,
-            "stop_sequences": [stop] if isinstance(stop, str) else stop,
-            "seed": seed,
-            "logit_bias": logit_bias,
-            "logprobs": logprobs,
-            "top_logprobs": top_logprobs
-        }
 
-        # Generate responses - now handles logprobs
-        responses, token_counts, logprobs_results = pipeline.generate(
-            prompt,
-            generation_params=generation_params
-        )
-
-        # Calculate prompt tokens
-        prompt_tokens = len(pipeline.tokenizer.encode(prompt))
-        completion_tokens = sum(token_counts)
-
-        # Create OpenAI-compatible response format
-        response_dict = {
-            "id": f"chatcmpl-{int(time.time()*1000)}",
-            "object": "chat.completion",
-            "created": int(time.time()),
-            "model": model,
-            "choices": [
-                {
-                    "index": idx,
-                    "message": {
-                        "role": "assistant",
-                        "content": response,
-                        **({"logprobs": logprob_result} if logprob_result else {})
-                    },
-                    "finish_reason": "stop"
+        responses = []
+        logprobs_results = []
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        try:
+            # Handle specialized decoding approaches
+            if decoding:
+                logger.info(f"Using specialized decoding approach: {decoding}")
+
+                if decoding == "cot_decoding":
+                    # Use directly available parameters for CoT
+                    cot_params = {
+                        "k": k,
+                        "num_beams": num_beams,
+                        "max_new_tokens": max_tokens if max_tokens is not None else 4096,
+                        "temperature": temperature,
+                        "top_p": top_p,
+                        "repetition_penalty": 1.0 + frequency_penalty,
+                        "length_penalty": length_penalty,
+                        "no_repeat_ngram_size": no_repeat_ngram_size,
+                        "early_stopping": early_stopping,
+                        "aggregate_paths": aggregate_paths,
+                    }
+
+                    result, confidence = cot_decode(
+                        pipeline.current_model,
+                        pipeline.tokenizer,
+                        messages,
+                        **cot_params
+                    )
+                    responses = [result]
+                    logprobs_results = [{"confidence_score": confidence} if confidence is not None else None]
+                    completion_tokens = len(pipeline.tokenizer.encode(result))
+
+                elif decoding == "entropy_decoding":
+                    # Configure generator for entropy decoding
+                    generator = None
+                    if seed is not None:
+                        generator = torch.Generator(device=pipeline.current_model.device)
+                        generator.manual_seed(seed)
+
+                    # Use directly available parameters for entropy decoding
+
+                    entropy_params = {
+                        "max_new_tokens": max_tokens if max_tokens is not None else 4096,
+                        "temperature": 0.666,
+                        "top_p": 0.90,
+                        "top_k": top_k,
+                        "min_p": min_p,
+                        "generator": generator
+                    }
+
+                    result = entropy_decode(
+                        pipeline.current_model,
+                        pipeline.tokenizer,
+                        messages,
+                        **entropy_params
+                    )
+                    responses = [result]
+                    logprobs_results = [None]
+                    completion_tokens = len(pipeline.tokenizer.encode(result))
+
+                else:
+                    raise ValueError(f"Unknown specialized decoding approach: {decoding}")
+
+                # Calculate prompt tokens for specialized approaches
+                prompt_text = pipeline.tokenizer.apply_chat_template(messages, tokenize=False)
+                prompt_tokens = len(pipeline.tokenizer.encode(prompt_text))
+
+            else:
+                # Standard generation
+                prompt = pipeline.tokenizer.apply_chat_template(
+                    messages,
+                    tokenize=False,
+                    add_generation_prompt=True
+                )
+
+                # Set generation parameters
+                generation_params = {
+                    "temperature": temperature,
+                    "top_p": top_p,
+                    "num_return_sequences": n,
+                    "max_new_tokens": max_tokens if max_tokens is not None else 4096,
+                    "presence_penalty": presence_penalty,
+                    "frequency_penalty": frequency_penalty,
+                    "stop_sequences": [stop] if isinstance(stop, str) else stop,
+                    "seed": seed,
+                    "logit_bias": logit_bias,
+                    "logprobs": logprobs,
+                    "top_logprobs": top_logprobs
+                }
+
+                # Generate responses
+                responses, token_counts, logprobs_results = pipeline.generate(
+                    prompt,
+                    generation_params=generation_params
+                )
+
+                prompt_tokens = len(pipeline.tokenizer.encode(prompt))
+                completion_tokens = sum(token_counts)
+
+            # Create OpenAI-compatible response format
+            response_dict = {
+                "id": f"chatcmpl-{int(time.time()*1000)}",
+                "object": "chat.completion",
+                "created": int(time.time()),
+                "model": model,
+                "choices": [
+                    {
+                        "index": idx,
+                        "message": {
+                            "role": "assistant",
+                            "content": response,
+                            **({"logprobs": logprob_result} if logprob_result else {})
+                        },
+                        "finish_reason": "stop"
+                    }
+                    for idx, (response, logprob_result) in enumerate(zip(responses, logprobs_results))
+                ],
+                "usage": {
+                    "prompt_tokens": prompt_tokens,
+                    "completion_tokens": completion_tokens,
+                    "total_tokens": completion_tokens + prompt_tokens
                 }
-                for idx, (response, logprob_result) in enumerate(zip(responses, logprobs_results))
-            ],
-            "usage": {
-                "prompt_tokens": prompt_tokens,
-                "completion_tokens": completion_tokens,
-                "total_tokens": completion_tokens + prompt_tokens
             }
-        }
-
-        self.client.clean_unused_pipelines()
-        logger.debug(f"Response : {response_dict}")
-        return ChatCompletion(response_dict)
+
+            self.client.clean_unused_pipelines()
+            logger.debug(f"Response : {response_dict}")
+            return ChatCompletion(response_dict)
+
+        except Exception as e:
+            logger.error(f"Error in chat completion: {str(e)}")
+            raise
 
 class Models:
     """OpenAI-compatible models interface"""
