@@ -41,6 +41,119 @@ class ModelConfig:
     enable_prompt_caching: bool = True
     dynamic_temperature: bool = True
 
+
+@dataclass
+class LogProbsResult:
+    """Container for logprobs calculation results"""
+    tokens: List[str]
+    token_logprobs: List[float]
+    top_logprobs: List[Dict[str, float]]
+    bytes_per_token: List[List[int]]
+
+class LogProbsCalculator:
+    """Handles calculation of log probabilities for generated tokens"""
+
+    def __init__(self, tokenizer, model):
+        self.tokenizer = tokenizer
+        self.model = model
+
+    def _get_bytes_for_token(self, token: str) -> List[int]:
+        """Get UTF-8 bytes for a token"""
+        try:
+            return list(token.encode('utf-8'))
+        except UnicodeEncodeError:
+            return []
+
+    def _get_top_alternatives(
+        self,
+        logits: torch.Tensor,
+        actual_token_id: int,
+        num_alternatives: int
+    ) -> Dict[str, float]:
+        """Calculate top alternative tokens and their logprobs"""
+        # Log probabilities over the vocabulary (log_softmax is numerically stabler than log(softmax))
+        logprobs = F.log_softmax(logits, dim=-1)
+
+        # Get top tokens excluding the actual token
+        top_values, top_indices = torch.topk(logprobs, k=num_alternatives + 1)
+
+        alternatives = {}
+        for value, idx in zip(top_values, top_indices):
+            token = self.tokenizer.decode([idx])
+            if idx != actual_token_id:  # Skip the actual token
+                alternatives[token] = value.item()
+            if len(alternatives) >= num_alternatives:
+                break
+
+        return alternatives
+
+    def calculate_logprobs(
+        self,
+        input_ids: torch.Tensor,
+        generated_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        num_alternatives: int = 5
+    ) -> LogProbsResult:
+        """Calculate log probabilities for a sequence of tokens"""
+        self.model.eval()
+
+        with torch.no_grad():
+            # Get model outputs for the entire sequence
+            outputs = self.model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                return_dict=True
+            )
+            logits = outputs.logits
+
+            # Log probabilities over the vocabulary at every position
+            # (softmax followed by log, computed in one numerically stable step)
+            logprobs = F.log_softmax(logits, dim=-1)
+
+            # Process each position
+            all_tokens = []
+            all_token_logprobs = []
+            all_top_logprobs = []
+            all_bytes = []
+
+            sequence_length = generated_ids.shape[-1]
+
+            for pos in range(sequence_length - 1):  # -1 because we look at the next token
+                next_token_id = generated_ids[0, pos + 1]
+                current_logits = logits[0, pos]
+
+                # Get token and its logprob
+                token = self.tokenizer.decode([next_token_id])
+                token_logprob = logprobs[0, pos, next_token_id].item()
+
+                # Get top alternative tokens
+                top_logprobs = self._get_top_alternatives(
+                    current_logits,
+                    next_token_id,
+                    num_alternatives
+                )
+
+                # Get bytes for token
+                token_bytes = self._get_bytes_for_token(token)
+
+                all_tokens.append(token)
+                all_token_logprobs.append(token_logprob)
+                all_top_logprobs.append(top_logprobs)
+                all_bytes.append(token_bytes)
+
+            # The first token has no preceding context: prepend a None logprob so entry i stays aligned with token i
+            all_tokens.insert(0, self.tokenizer.decode([generated_ids[0, 0]]))
+            all_token_logprobs.insert(0, None)
+            all_top_logprobs.insert(0, None)
+            all_bytes.insert(0, self._get_bytes_for_token(all_tokens[0]))
+
+        return LogProbsResult(
+            tokens=all_tokens,
+            token_logprobs=all_token_logprobs,
+            top_logprobs=all_top_logprobs,
+            bytes_per_token=all_bytes
+        )
+
 class MemoryEfficientAttention(nn.Module):
     """
     Memory-efficient attention using linear attention mechanism.
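
The calculator added above can be exercised on its own. A minimal sketch, assuming a Hugging Face causal LM and tokenizer are already available ("gpt2" and the variable names below are placeholders, not part of this change):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    prompt_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
    generated = model.generate(prompt_ids, max_new_tokens=5, do_sample=False)

    calculator = LogProbsCalculator(tokenizer, model)
    result = calculator.calculate_logprobs(
        input_ids=generated,
        generated_ids=generated,
        attention_mask=torch.ones_like(generated),
        num_alternatives=3,
    )
    # Token 0 carries a None logprob (no preceding context); later tokens carry real values
    for token, logprob in zip(result.tokens, result.token_logprobs):
        print(repr(token), logprob)
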
@@ -561,7 +674,7 @@ def generate(
         prompt: str,
         generation_params: Optional[Dict[str, Any]] = None
-    ) -> Tuple[List[str], List[int]]:
-        """Generate multiple responses for a prompt when n > 1"""
+    ) -> Tuple[List[str], List[int], List[Optional[Dict]]]:
+        """Generate completions, optionally with per-token logprobs"""
 
         # Tokenize input
         inputs = self.tokenizer(
@@ -570,7 +683,17 @@ def generate(
             truncation=True,
             return_tensors="pt"
         ).to(self.current_model.device)
-
+
+        # Extract logprobs parameters
+        calculate_logprobs = generation_params.get("logprobs", False)
+        top_logprobs = generation_params.get("top_logprobs", 0)
+
+        if top_logprobs and not calculate_logprobs:
+            raise ValueError("logprobs must be true when top_logprobs is specified")
+
+        if top_logprobs and not (0 <= top_logprobs <= 20):
+            raise ValueError("top_logprobs must be between 0 and 20")
+
         # Configure generation parameters
         gen_config = {
             "max_new_tokens": generation_params.get("max_new_tokens", 4096),
@@ -580,8 +703,11 @@ def generate(
             "num_return_sequences": generation_params.get("num_return_sequences", 1),
             "pad_token_id": self.tokenizer.pad_token_id,
             "eos_token_id": self.tokenizer.eos_token_id,
+            "return_dict_in_generate": True,
+            "output_scores": calculate_logprobs,
         }
-
+
+        # Add optional parameters
         if generation_params:
            if generation_params.get("presence_penalty", 0) != 0:
                gen_config["presence_penalty"] = generation_params["presence_penalty"]
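
With return_dict_in_generate=True, Hugging Face generate() returns a structured output object rather than a bare tensor, which is why the hunk below reads outputs.sequences; output_scores=True additionally exposes the per-step logits as outputs.scores, although this change recomputes logprobs with a second forward pass instead of reusing them. A rough sketch of the two access patterns (variable names are illustrative only):

    import torch.nn.functional as F

    outputs = model.generate(**inputs, max_new_tokens=8,
                             return_dict_in_generate=True, output_scores=True)
    sequences = outputs.sequences      # (batch * num_return_sequences, seq_len)
    step_scores = outputs.scores       # tuple of (batch, vocab) logits, one per new token

    # Logprob of the token emitted at step t of sequence 0,
    # taken directly from the generation-time scores
    t = 0
    prompt_len = inputs["input_ids"].shape[1]
    chosen_id = sequences[0, prompt_len + t]
    logprob = F.log_softmax(step_scores[t][0], dim=-1)[chosen_id].item()
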
@@ -596,28 +722,55 @@ def generate(
                 torch.manual_seed(generation_params["seed"])
                 if torch.cuda.is_available():
                     torch.cuda.manual_seed(generation_params["seed"])
-
+
         # Generate responses
         with torch.amp.autocast('cuda', dtype=self.dtype):
             with torch.no_grad():
                 outputs = self.current_model.generate(
                     **inputs,
                     **gen_config
                 )
-
-        # Process outputs - now handling multiple sequences
+
+        generated_sequences = outputs.sequences
         input_length = inputs['input_ids'].shape[1]
+
         responses = []
         token_counts = []
-
-        # For each generated sequence
-        for output in outputs:
-            response_tokens = output[input_length:]
+        logprobs_results = []
+
+        # Process each generated sequence
+        for sequence in generated_sequences:
+            response_tokens = sequence[input_length:]
             response_text = self.tokenizer.decode(response_tokens, skip_special_tokens=True)
             responses.append(response_text)
             token_counts.append(len(response_tokens))
-
-        return responses, token_counts
+
+            # Calculate logprobs if requested
+            if calculate_logprobs:
+                calculator = LogProbsCalculator(self.tokenizer, self.current_model)
+                logprobs_result = calculator.calculate_logprobs(
+                    input_ids=sequence.unsqueeze(0),
+                    generated_ids=sequence.unsqueeze(0),
+                    attention_mask=torch.ones_like(sequence).unsqueeze(0),
+                    num_alternatives=top_logprobs or 5
+                )
+                logprobs_results.append({
+                    "content": [{
+                        "token": token,
+                        "logprob": logprob,
+                        "bytes": bytes_,
+                        "top_logprobs": token_top_logprobs
+                    } for token, logprob, bytes_, token_top_logprobs in zip(
+                        logprobs_result.tokens[input_length:],
+                        logprobs_result.token_logprobs[input_length:],
+                        logprobs_result.bytes_per_token[input_length:],
+                        logprobs_result.top_logprobs[input_length:]
+                    )]
+                })
+            else:
+                logprobs_results.append(None)
+
+        return responses, token_counts, logprobs_results
 
     def setup_efficient_attention(self):
         """Replace standard attention with memory-efficient version"""
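
Each non-None entry in logprobs_results mirrors the OpenAI chat-completions logprobs layout, with one dict per completion token; the one deviation is that top_logprobs is a token-to-logprob mapping here rather than a list of objects. An illustrative entry (all values invented):

    {
        "content": [
            {
                "token": " Paris",
                "logprob": -0.01,
                "bytes": [32, 80, 97, 114, 105, 115],
                "top_logprobs": {" Lyon": -5.2, " Nice": -6.8, " Mars": -7.9}
            },
            # ... one dict per generated token
        ]
    }
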
@@ -917,15 +1070,24 @@ def process_batch(
         return all_responses, [0] * len(all_responses)
 
 class ChatCompletionMessage:
-    def __init__(self, content: str, role: str = "assistant"):
+    def __init__(self, content: str, role: str = "assistant", logprobs: Optional[Dict] = None):
         self.content = content
         self.role = role
+        self.logprobs = logprobs
 
 class ChatCompletionChoice:
-    def __init__(self, index: int, message: Dict[str, str], finish_reason: str = "stop"):
+    def __init__(
+        self,
+        index: int,
+        message: Dict[str, Any],
+        finish_reason: str = "stop",
+        logprobs: Optional[Dict] = None
+    ):
         self.index = index
         self.message = ChatCompletionMessage(**message)
         self.finish_reason = finish_reason
+        if logprobs:
+            self.message.logprobs = logprobs
 
 class ChatCompletionUsage:
     def __init__(self, prompt_tokens: int, completion_tokens: int, total_tokens: int):
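
For reference, the updated choice wrapper accepts logprobs either embedded in the message dict or as its own argument; both paths end up on message.logprobs. A small sketch (lp stands for a logprobs dict of the shape shown earlier):

    choice = ChatCompletionChoice(
        index=0,
        message={"role": "assistant", "content": "Paris", "logprobs": lp},
    )
    # is equivalent to
    choice = ChatCompletionChoice(
        index=0,
        message={"role": "assistant", "content": "Paris"},
        logprobs=lp,
    )
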
@@ -950,7 +1112,6 @@ def __init__(self, response_dict: Dict):
         self.usage = ChatCompletionUsage(**response_dict["usage"])
 
     def model_dump(self) -> Dict:
-        """Convert back to dictionary format if needed"""
         return {
             "id": self.id,
             "object": self.object,
@@ -960,6 +1121,10 @@ def model_dump(self) -> Dict:
                 {
                     "index": choice.index,
                     "message": {
+                        "role": choice.message.role,
+                        "content": choice.message.content,
+                        "logprobs": choice.message.logprobs
+                    } if choice.message.logprobs else {
                         "role": choice.message.role,
                         "content": choice.message.content
                     },
@@ -973,6 +1138,7 @@ def model_dump(self) -> Dict:
                 "total_tokens": self.usage.total_tokens
             }
         }
+
 class InferenceClient:
     """OpenAI SDK Compatible client for local inference with dynamic model support"""
 
@@ -1034,6 +1200,8 @@ def create(
         logit_bias: Optional[Dict[str, float]] = None,
         user: Optional[str] = None,
         seed: Optional[int] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
         **kwargs
     ) -> ChatCompletion:
         """Create a chat completion with OpenAI-compatible parameters"""
@@ -1059,11 +1227,13 @@ def create(
             "frequency_penalty": frequency_penalty,
             "stop_sequences": [stop] if isinstance(stop, str) else stop,
             "seed": seed,
-            "logit_bias": logit_bias
+            "logit_bias": logit_bias,
+            "logprobs": logprobs,
+            "top_logprobs": top_logprobs
         }
 
-        # Generate responses - now returns list of responses and token counts
-        responses, token_counts = pipeline.generate(
+        # Generate responses - now handles logprobs
+        responses, token_counts, logprobs_results = pipeline.generate(
             prompt,
             generation_params=generation_params
         )
@@ -1083,11 +1253,12 @@ def create(
                     "index": idx,
                     "message": {
                         "role": "assistant",
-                        "content": response
+                        "content": response,
+                        **({"logprobs": logprob_result} if logprob_result else {})
                     },
                     "finish_reason": "stop"
                 }
-                for idx, response in enumerate(responses)
+                for idx, (response, logprob_result) in enumerate(zip(responses, logprobs_results))
             ],
             "usage": {
                 "prompt_tokens": prompt_tokens,
@@ -1097,9 +1268,8 @@ def create(
             }
         }
         self.client.clean_unused_pipelines()
-        # Return ChatCompletion object
         return ChatCompletion(response_dict)
-
+
     class Models:
         """OpenAI-compatible models interface"""
         def list(self):
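
End to end, the new parameters are used the same way as with the OpenAI SDK. A hedged sketch, assuming InferenceClient exposes the usual chat.completions.create entry point as the surrounding class names suggest (the model id and prompt are placeholders):

    client = InferenceClient()
    completion = client.chat.completions.create(
        model="local-model",  # placeholder model id
        messages=[{"role": "user", "content": "Name the capital of France."}],
        logprobs=True,
        top_logprobs=3,
    )

    choice = completion.choices[0]
    print(choice.message.content)
    if choice.message.logprobs:
        first = choice.message.logprobs["content"][0]
        print(first["token"], first["logprob"], first["top_logprobs"])
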