@@ -162,10 +162,12 @@ async def initialize(self, properties: dict):
162162 self .session_manager : SessionManager = SessionManager (properties )
163163 self .initialized = True
164164
def _get_custom_formatter(self,
                          adapter_name: Optional[str] = None) -> bool:
    """Check if a custom output formatter exists for the adapter or base model.

    Despite the ``_get`` prefix, this returns a flag rather than the
    formatter itself.

    Args:
        adapter_name: Name of the adapter to look up first. When it is
            ``None`` or empty, only the base-model formatter is considered.

    Returns:
        ``True`` when the adapter's formatter handler exists and exposes a
        truthy ``output_formatter``; otherwise whether
        ``self.output_formatter`` is set (not ``None``).
    """
    if adapter_name:
        adapter_formatter = self.get_adapter_formatter_handler(adapter_name)
        # The adapter-level formatter is checked first; if the handler is
        # missing or has no formatter, fall through to the base-model check.
        if adapter_formatter and adapter_formatter.output_formatter:
            return True
    return self.output_formatter is not None
@@ -263,7 +265,9 @@ async def check_health(self):
263265 logger .fatal ("vLLM engine is dead, terminating process" )
264266 kill_process_tree (os .getpid ())
265267
266- async def inference (self , inputs : Input ) -> Union [Output , AsyncGenerator [Output , None ]]:
268+ async def inference (
269+ self ,
270+ inputs : Input ) -> Union [Output , AsyncGenerator [Output , None ]]:
267271 await self .check_health ()
268272 try :
269273 processed_request = self .preprocess_request (inputs )
@@ -281,10 +285,12 @@ async def inference(self, inputs: Input) -> Union[Output, AsyncGenerator[Output,
281285 processed_request .vllm_request )
282286
283287 # Check if custom formatter exists (applies to both streaming and non-streaming)
284- custom_formatter = self ._get_custom_formatter (processed_request .adapter_name )
288+ custom_formatter = self ._get_custom_formatter (
289+ processed_request .adapter_name )
285290
286291 if isinstance (response , types .AsyncGeneratorType ):
287- return self ._handle_streaming_response (response , processed_request , custom_formatter )
292+ return self ._handle_streaming_response (response , processed_request ,
293+ custom_formatter )
288294
289295 # Non-streaming response
290296 if custom_formatter :
@@ -296,32 +302,34 @@ async def inference(self, inputs: Input) -> Union[Output, AsyncGenerator[Output,
296302 elif hasattr (formatted_response , 'model_dump' ):
297303 formatted_response = formatted_response .model_dump ()
298304 return create_non_stream_output (formatted_response )
299-
305+
300306 # LMI formatter for non-streaming
301307 return processed_request .non_stream_output_formatter (
302308 response ,
303309 request = processed_request .vllm_request ,
304310 tokenizer = self .tokenizer ,
305311 )
306312
async def _handle_streaming_response(self, response, processed_request,
                                     custom_formatter):
    """Handle streaming responses as an async generator.

    Args:
        response: Async iterable producing the raw streamed chunks.
        processed_request: Request object; this method reads its
            ``adapter_name``, ``stream_output_formatter``, ``vllm_request``,
            ``accumulate_chunks`` and ``include_prompt`` attributes.
        custom_formatter: Truthy when a custom output formatter should be
            applied to every chunk instead of the LMI stream formatter.

    Yields:
        With a custom formatter: one ``create_stream_chunk_output`` result per
        formatted chunk (``last_chunk=False``), followed by a final empty
        chunk flagged ``last_chunk=True``. Otherwise: every item produced by
        ``handle_streaming_response`` unchanged.
    """
    if custom_formatter:
        # Custom formatter: apply to each chunk and yield directly
        async for chunk in response:
            formatted_chunk = self.apply_output_formatter(
                chunk, adapter_name=processed_request.adapter_name)
            yield create_stream_chunk_output(formatted_chunk,
                                             last_chunk=False)
        # Explicit terminator: an empty payload marked as the last chunk.
        yield create_stream_chunk_output("", last_chunk=True)
    else:
        # LMI formatter for streaming
        async for output in handle_streaming_response(
                response,
                processed_request.stream_output_formatter,
                request=processed_request.vllm_request,
                accumulate_chunks=processed_request.accumulate_chunks,
                include_prompt=processed_request.include_prompt,
                tokenizer=self.tokenizer,
        ):
            yield output
327335
0 commit comments