 from graphgen.bases.datatypes import Token


-# TODO: implement SGLangWrapper methods
 class SGLangWrapper(BaseLLMWrapper):
     """
     Async inference backend based on SGLang offline engine.
@@ -59,43 +58,39 @@ def _build_sampling_params(
             params["logprobs"] = topk
         return params

-    def _prep_prompt(self, text: str, history: Optional[List[str]] = None) -> str:
+    def _prep_prompt(self, text: str, history: Optional[List[dict]] = None) -> str:
         """Convert raw text (+ optional history) into a single prompt string."""
         parts = []
         if self.system_prompt:
             parts.append(self.system_prompt)
         if history:
             assert len(history) % 2 == 0, "History must have even length (u/a turns)."
-            parts.extend(history)
+            parts.extend([item["content"] for item in history])
         parts.append(text)
         return "\n".join(parts)

     def _tokens_from_output(self, output: Dict[str, Any]) -> List[Token]:
-        """
-        Convert SGLang logprobs output into List[Token].
-        SGLang returns:
-        output['logprobs'][t] -> {
-            "token": <str>,
-            "logprob": <float>,
-            "top_k_tokens": [...],
-            "top_k_logprobs": [...],
-        }
-        """
         tokens: List[Token] = []
-        if "logprobs" not in output or not output["logprobs"]:
-            return tokens

-        for entry in output["logprobs"]:
-            token_str = entry["token"]
-            logprob = entry["logprob"]
-            prob = math.exp(logprob)
+        meta = output.get("meta_info", {})
+        logprobs = meta.get("output_token_logprobs", [])
+        topks = meta.get("output_top_logprobs", [])
+
+        tokenizer = self.engine.tokenizer_manager.tokenizer
+
+        for idx, (lp, tid, _) in enumerate(logprobs):
+            prob = math.exp(lp)
+            tok_str = tokenizer.decode([tid])

             top_candidates = []
-            if self.topk > 0 and "top_k_tokens" in entry:
-                for tok, lp in zip(entry["top_k_tokens"], entry["top_k_logprobs"]):
-                    top_candidates.append(Token(tok, math.exp(lp)))
+            if self.topk > 0 and idx < len(topks):
+                for t_lp, t_tid, _ in topks[idx][: self.topk]:
+                    top_candidates.append(
+                        Token(text=tokenizer.decode([t_tid]), prob=math.exp(t_lp))
+                    )
+
+            tokens.append(Token(text=tok_str, prob=prob, top_candidates=top_candidates))

-            tokens.append(Token(token_str, prob, top_candidates=top_candidates))
         return tokens

     async def generate_answer(
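For reference, the rewritten _tokens_from_output above reads logprobs from meta_info rather than a top-level "logprobs" key. A minimal sketch of the layout it unpacks, with invented values; only the tuple order (logprob, token_id, ...) and the field names are taken from the code above, everything else is illustrative:

# Sketch of the output dict shape assumed by _tokens_from_output when SGLang's
# async_generate is called with return_logprob=True.  Values are made up.
import math

output = {
    "text": " Paris",
    "meta_info": {
        # one (logprob, token_id, optional_text) tuple per generated token
        "output_token_logprobs": [(-0.11, 12095, None)],
        # per position: a list of (logprob, token_id, optional_text) tuples
        "output_top_logprobs": [[(-0.11, 12095, None), (-2.6, 4874, None)]],
    },
}

logprob, token_id, _ = output["meta_info"]["output_token_logprobs"][0]
print(token_id, math.exp(logprob))  # token id and its probability
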
@@ -112,7 +107,7 @@ async def generate_answer(
             topk=0,  # no logprobs needed for simple generation
         )

-        outputs = self.engine.generate([prompt], sampling_params)
+        outputs = await self.engine.async_generate([prompt], sampling_params)
         return self.filter_think_tags(outputs[0]["text"])

     async def generate_topk_per_token(
@@ -125,45 +120,23 @@ async def generate_topk_per_token(
         sampling_params = self._build_sampling_params(
             temperature=self.temperature,
             top_p=self.top_p,
-            max_tokens=5,  # keep short for token-level analysis
+            max_tokens=1,  # keep short for token-level analysis
             topk=self.topk,
-            logprobs=True,
         )

-        outputs = self.engine.generate([prompt], sampling_params)
+        outputs = await self.engine.async_generate(
+            [prompt], sampling_params, return_logprob=True, top_logprobs_num=5
+        )
+        print(outputs)
         return self._tokens_from_output(outputs[0])

     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> List[Token]:
-        """
-        Return per-token probabilities for the *input* sequence.
-        SGLang offline engine does not expose this directly; we emulate by
-        generating 0 new tokens with logprobs enabled (returns prompt logprobs).
-        """
-        prompt = self._prep_prompt(text, history)
-        sampling_params = self._build_sampling_params(
-            temperature=0.0,  # deterministic
-            top_p=1.0,
-            max_tokens=0,  # generate nothing
-            topk=self.topk,
-            logprobs=True,
+        raise NotImplementedError(
+            "SGLangWrapper does not support per-token logprobs yet."
         )

-        outputs = self.engine.generate([prompt], sampling_params)
-        # SGLang returns prompt logprobs under key 'prompt_logprobs' when max_new_tokens=0
-        prompt_logprobs = outputs[0].get("prompt_logprobs", [])
-        tokens: List[Token] = []
-        for entry in prompt_logprobs:
-            tokens.append(
-                Token(
-                    text=entry["token"],
-                    prob=math.exp(entry["logprob"]),
-                    top_candidates=[],  # SGLang does not give top-k for prompt tokens
-                )
-            )
-        return tokens
-
     def shutdown(self) -> None:
         """Gracefully shutdown the SGLang engine."""
         if hasattr(self, "engine"):
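As a usage sketch: the constructor is not part of this diff, so the import path and __init__ arguments below are hypothetical; only generate_answer, generate_topk_per_token, shutdown, and the Token fields (text, prob, top_candidates) mirror the code above.

import asyncio

# Hypothetical import path and constructor arguments; the diff shows only the
# generation methods, not __init__ or the module location.
from graphgen.models import SGLangWrapper

async def main():
    llm = SGLangWrapper(model="Qwen/Qwen2.5-7B-Instruct", temperature=0.7, top_p=0.9, topk=5)

    answer = await llm.generate_answer("What is the capital of France?")
    print(answer)

    # Per-token view of the first generated token plus its top-k alternatives.
    for tok in await llm.generate_topk_per_token("What is the capital of France?"):
        print(tok.text, tok.prob, [(c.text, c.prob) for c in tok.top_candidates])

    llm.shutdown()

asyncio.run(main())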