
Commit 60a84ef

return logprobs and toplogprobs
add support for logprobs
1 parent 23e48ef commit 60a84ef

File tree: 1 file changed

optillm/inference.py (192 additions, 22 deletions)

@@ -41,6 +41,119 @@ class ModelConfig:
     enable_prompt_caching: bool = True
     dynamic_temperature: bool = True

+
+@dataclass
+class LogProbsResult:
+    """Container for logprobs calculation results"""
+    tokens: List[str]
+    token_logprobs: List[float]
+    top_logprobs: List[Dict[str, float]]
+    bytes_per_token: List[List[int]]
+
+class LogProbsCalculator:
+    """Handles calculation of log probabilities for generated tokens"""
+
+    def __init__(self, tokenizer, model):
+        self.tokenizer = tokenizer
+        self.model = model
+
+    def _get_bytes_for_token(self, token: str) -> List[int]:
+        """Get UTF-8 bytes for a token"""
+        try:
+            return list(token.encode('utf-8'))
+        except UnicodeEncodeError:
+            return []
+
+    def _get_top_alternatives(
+        self,
+        logits: torch.Tensor,
+        actual_token_id: int,
+        num_alternatives: int
+    ) -> Dict[str, float]:
+        """Calculate top alternative tokens and their logprobs"""
+        probs = F.softmax(logits, dim=-1)
+        logprobs = torch.log(probs)
+
+        # Get top tokens excluding the actual token
+        top_values, top_indices = torch.topk(logprobs, k=num_alternatives + 1)
+
+        alternatives = {}
+        for value, idx in zip(top_values, top_indices):
+            token = self.tokenizer.decode([idx])
+            if idx != actual_token_id:  # Skip the actual token
+                alternatives[token] = value.item()
+                if len(alternatives) >= num_alternatives:
+                    break
+
+        return alternatives
+
+    def calculate_logprobs(
+        self,
+        input_ids: torch.Tensor,
+        generated_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        num_alternatives: int = 5
+    ) -> LogProbsResult:
+        """Calculate log probabilities for a sequence of tokens"""
+        self.model.eval()
+
+        with torch.no_grad():
+            # Get model outputs for the entire sequence
+            outputs = self.model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                return_dict=True
+            )
+            logits = outputs.logits
+
+            # Calculate softmax and log probabilities
+            probs = F.softmax(logits, dim=-1)
+            logprobs = torch.log(probs)
+
+            # Process each position
+            all_tokens = []
+            all_token_logprobs = []
+            all_top_logprobs = []
+            all_bytes = []
+
+            sequence_length = generated_ids.shape[-1]
+
+            for pos in range(sequence_length - 1):  # -1 because we look at next token
+                next_token_id = generated_ids[0, pos + 1]
+                current_logits = logits[0, pos]
+
+                # Get token and its logprob
+                token = self.tokenizer.decode([next_token_id])
+                token_logprob = logprobs[0, pos, next_token_id].item()
+
+                # Get top alternative tokens
+                top_logprobs = self._get_top_alternatives(
+                    current_logits,
+                    next_token_id,
+                    num_alternatives
+                )
+
+                # Get bytes for token
+                token_bytes = self._get_bytes_for_token(token)
+
+                all_tokens.append(token)
+                all_token_logprobs.append(token_logprob)
+                all_top_logprobs.append(top_logprobs)
+                all_bytes.append(token_bytes)
+
+            # Add None for the last token
+            all_tokens.append(self.tokenizer.decode([generated_ids[0, -1]]))
+            all_token_logprobs.append(None)
+            all_top_logprobs.append(None)
+            all_bytes.append(self._get_bytes_for_token(all_tokens[-1]))
+
+            return LogProbsResult(
+                tokens=all_tokens,
+                token_logprobs=all_token_logprobs,
+                top_logprobs=all_top_logprobs,
+                bytes_per_token=all_bytes
+            )
+
 class MemoryEfficientAttention(nn.Module):
     """
     Memory-efficient attention using linear attention mechanism.
@@ -561,7 +674,7 @@ def generate(
         prompt: str,
         generation_params: Optional[Dict[str, Any]] = None
     ) -> Tuple[List[str], List[int]]:
-        """Generate multiple responses for a prompt when n > 1"""
+        """Generate completions with optional logprobs"""

         # Tokenize input
         inputs = self.tokenizer(
@@ -570,7 +683,17 @@ def generate(
             truncation=True,
             return_tensors="pt"
         ).to(self.current_model.device)
-
+
+        # Extract logprobs parameters
+        calculate_logprobs = generation_params.get("logprobs", False)
+        top_logprobs = generation_params.get("top_logprobs", 0)
+
+        if top_logprobs and not calculate_logprobs:
+            raise ValueError("logprobs must be true when top_logprobs is specified")
+
+        if top_logprobs and not (0 <= top_logprobs <= 20):
+            raise ValueError("top_logprobs must be between 0 and 20")
+
         # Configure generation parameters
         gen_config = {
             "max_new_tokens": generation_params.get("max_new_tokens", 4096),
@@ -580,8 +703,11 @@ def generate(
             "num_return_sequences": generation_params.get("num_return_sequences", 1),
             "pad_token_id": self.tokenizer.pad_token_id,
             "eos_token_id": self.tokenizer.eos_token_id,
+            "return_dict_in_generate": True,
+            "output_scores": calculate_logprobs,
         }
-
+
+        # Add optional parameters
         if generation_params:
             if generation_params.get("presence_penalty", 0) != 0:
                 gen_config["presence_penalty"] = generation_params["presence_penalty"]
@@ -596,28 +722,55 @@ def generate(
                 torch.manual_seed(generation_params["seed"])
                 if torch.cuda.is_available():
                     torch.cuda.manual_seed(generation_params["seed"])
-
+
         # Generate responses
         with torch.amp.autocast('cuda', dtype=self.dtype):
             with torch.no_grad():
                 outputs = self.current_model.generate(
                     **inputs,
                     **gen_config
                 )
-
-        # Process outputs - now handling multiple sequences
+
+        generated_sequences = outputs.sequences
         input_length = inputs['input_ids'].shape[1]
+
         responses = []
         token_counts = []
-
-        # For each generated sequence
-        for output in outputs:
-            response_tokens = output[input_length:]
+        logprobs_results = []
+
+        # Process each generated sequence
+        for sequence in generated_sequences:
+            response_tokens = sequence[input_length:]
             response_text = self.tokenizer.decode(response_tokens, skip_special_tokens=True)
             responses.append(response_text)
             token_counts.append(len(response_tokens))
-
-        return responses, token_counts
+
+            # Calculate logprobs if requested
+            if calculate_logprobs:
+                calculator = LogProbsCalculator(self.tokenizer, self.current_model)
+                logprobs_result = calculator.calculate_logprobs(
+                    input_ids=sequence.unsqueeze(0),
+                    generated_ids=sequence.unsqueeze(0),
+                    attention_mask=torch.ones_like(sequence).unsqueeze(0),
+                    num_alternatives=top_logprobs or 5
+                )
+                logprobs_results.append({
+                    "content": [{
+                        "token": token,
+                        "logprob": logprob,
+                        "bytes": bytes_,
+                        "top_logprobs": top_logprobs
+                    } for token, logprob, bytes_, top_logprobs in zip(
+                        logprobs_result.tokens[input_length:],
+                        logprobs_result.token_logprobs[input_length:],
+                        logprobs_result.bytes_per_token[input_length:],
+                        logprobs_result.top_logprobs[input_length:]
+                    )]
+                })
+            else:
+                logprobs_results.append(None)
+
+        return responses, token_counts, logprobs_results

     def setup_efficient_attention(self):
         """Replace standard attention with memory-efficient version"""
@@ -917,15 +1070,24 @@ def process_batch(
         return all_responses, [0] * len(all_responses)

 class ChatCompletionMessage:
-    def __init__(self, content: str, role: str = "assistant"):
+    def __init__(self, content: str, role: str = "assistant", logprobs: Optional[Dict] = None):
         self.content = content
         self.role = role
+        self.logprobs = logprobs

 class ChatCompletionChoice:
-    def __init__(self, index: int, message: Dict[str, str], finish_reason: str = "stop"):
+    def __init__(
+        self,
+        index: int,
+        message: Dict[str, Any],
+        finish_reason: str = "stop",
+        logprobs: Optional[Dict] = None
+    ):
         self.index = index
         self.message = ChatCompletionMessage(**message)
         self.finish_reason = finish_reason
+        if logprobs:
+            self.message.logprobs = logprobs

 class ChatCompletionUsage:
     def __init__(self, prompt_tokens: int, completion_tokens: int, total_tokens: int):
@@ -950,7 +1112,6 @@ def __init__(self, response_dict: Dict):
         self.usage = ChatCompletionUsage(**response_dict["usage"])

     def model_dump(self) -> Dict:
-        """Convert back to dictionary format if needed"""
         return {
             "id": self.id,
             "object": self.object,
@@ -960,6 +1121,10 @@ def model_dump(self) -> Dict:
                 {
                     "index": choice.index,
                     "message": {
+                        "role": choice.message.role,
+                        "content": choice.message.content,
+                        "logprobs": choice.message.logprobs
+                    } if choice.message.logprobs else {
                         "role": choice.message.role,
                         "content": choice.message.content
                     },
@@ -973,6 +1138,7 @@ def model_dump(self) -> Dict:
                 "total_tokens": self.usage.total_tokens
             }
         }
+
 class InferenceClient:
     """OpenAI SDK Compatible client for local inference with dynamic model support"""

@@ -1034,6 +1200,8 @@ def create(
         logit_bias: Optional[Dict[str, float]] = None,
         user: Optional[str] = None,
         seed: Optional[int] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
         **kwargs
     ) -> ChatCompletion:
         """Create a chat completion with OpenAI-compatible parameters"""
@@ -1059,11 +1227,13 @@ def create(
             "frequency_penalty": frequency_penalty,
             "stop_sequences": [stop] if isinstance(stop, str) else stop,
             "seed": seed,
-            "logit_bias": logit_bias
+            "logit_bias": logit_bias,
+            "logprobs": logprobs,
+            "top_logprobs": top_logprobs
         }

-        # Generate responses - now returns list of responses and token counts
-        responses, token_counts = pipeline.generate(
+        # Generate responses - now handles logprobs
+        responses, token_counts, logprobs_results = pipeline.generate(
             prompt,
             generation_params=generation_params
         )
@@ -1083,11 +1253,12 @@ def create(
                     "index": idx,
                     "message": {
                         "role": "assistant",
-                        "content": response
+                        "content": response,
+                        **({"logprobs": logprob_result} if logprob_result else {})
                     },
                     "finish_reason": "stop"
                 }
-                for idx, response in enumerate(responses)
+                for idx, (response, logprob_result) in enumerate(zip(responses, logprobs_results))
             ],
             "usage": {
                 "prompt_tokens": prompt_tokens,
@@ -1097,9 +1268,8 @@ def create(
             }
         }

         self.client.clean_unused_pipelines()
-        # Return ChatCompletion object
         return ChatCompletion(response_dict)
-
+
     class Models:
         """OpenAI-compatible models interface"""
         def list(self):

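Usage note (not part of the commit): the sketch below shows how a caller might exercise the new logprobs and top_logprobs parameters through the OpenAI-compatible interface and read back the per-token results. The import path, the client construction, the chat.completions nesting, and the model id are assumptions for illustration; only the two parameters, their validation rules, and the shape of the logprobs payload come from the diff above.

# Hypothetical usage sketch; see the caveats in the note above.
from optillm.inference import InferenceClient  # assumed import path

client = InferenceClient()  # constructor arguments are not shown in this diff

response = client.chat.completions.create(     # assumed OpenAI-SDK-style nesting
    model="Qwen/Qwen2.5-0.5B-Instruct",        # illustrative local model id
    messages=[{"role": "user", "content": "Say hello"}],
    logprobs=True,      # enables output_scores and per-token logprob calculation
    top_logprobs=3,     # must satisfy 0 <= top_logprobs <= 20 and requires logprobs=True
)

choice = response.model_dump()["choices"][0]
if choice["message"].get("logprobs"):
    for entry in choice["message"]["logprobs"]["content"]:
        # Each entry carries the token text, its logprob, its UTF-8 bytes,
        # and a dict of alternative tokens with their logprobs.
        print(entry["token"], entry["logprob"], entry["top_logprobs"])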
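For reference, the next-token logprob computation that LogProbsCalculator performs can be reproduced standalone with Hugging Face transformers. This is an independent sketch under its own assumptions (the 'gpt2' model id and all variable names are illustrative, not from the repository); it uses F.log_softmax, the numerically stable equivalent of the softmax-then-log pair used in the committed code.

# Standalone illustration of per-token logprobs for a causal LM.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

ids = tokenizer("The capital of France is Paris", return_tensors="pt").input_ids

with torch.no_grad():
    logits = model(input_ids=ids).logits          # shape: (1, seq_len, vocab_size)

# log_softmax(x) equals log(softmax(x)), computed in a numerically stable way.
logprobs = F.log_softmax(logits, dim=-1)

# Logits at position p predict the token at position p + 1, so the logprob of
# each realised token is read from the preceding position.
next_ids = ids[0, 1:]
token_logprobs = logprobs[0, :-1].gather(1, next_ids.unsqueeze(1)).squeeze(1)

# Top-k alternatives at every position, analogous to _get_top_alternatives.
top_values, top_indices = torch.topk(logprobs[0, :-1], k=5, dim=-1)

for tok_id, lp in zip(next_ids.tolist(), token_logprobs.tolist()):
    print(repr(tokenizer.decode([tok_id])), round(lp, 4))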