@@ -1467,7 +1467,7 @@ async def generation_single_request(task: dict[str, Any]):
14671467 batch_request_ids , batch_request_outputs , _gen_ms_list , batch_metrics
14681468 ):
14691469 try :
1470- r_outputs = [output ]
1470+ r_outputs = [output_strip ( output , omni_stage ) ]
14711471 use_shm , payload = maybe_dump_to_shm (r_outputs , shm_threshold_bytes )
14721472 if use_shm :
14731473 out_q .put (
@@ -1553,3 +1553,32 @@ def make_stage_stats(_agg_total_tokens: int, _agg_total_gen_time_ms: float):
15531553 from vllm_omni .entrypoints .log_utils import StageStats
15541554
15551555 return StageStats (total_token = _agg_total_tokens , total_gen_time = _agg_total_gen_time_ms )
1556+
1557+
def output_strip(r_output: RequestOutput | OmniRequestOutput, omni_stage: OmniStage):
    """Clear multimodal payloads from a stage result when they are not needed.

    Dropping these payloads is meant to lower memory usage and to cut
    transfer and serialization overhead between stages.

    The result is returned untouched when either:
      - the stage is flagged as a final output and its final output type is
        anything other than ``"text"``, or
      - the result reports itself as ``finished``.

    Otherwise ``multimodal_output`` is reset to an empty dict on the result
    itself (when the attribute is present and not ``None``) and on every
    entry of ``r_output.outputs`` whose ``multimodal_output`` is truthy.

    Args:
        r_output: Stage result to inspect; it is modified in place.
        omni_stage: Stage whose ``final_output`` / ``final_output_type``
            settings decide whether multimodal data must be preserved.

    Returns:
        The same ``r_output`` object that was passed in.
    """
    # A final stage producing non-text output still needs its multimodal data.
    keep_multimodal = omni_stage.final_output and omni_stage.final_output_type != "text"

    # Finished requests are passed through unchanged as well.
    if keep_multimodal or getattr(r_output, "finished", False):
        return r_output

    # Top-level payload: cleared whenever it exists, even if already empty.
    if getattr(r_output, "multimodal_output", None) is not None:
        r_output.multimodal_output = {}

    # Per-completion payloads: only non-empty ones are reset.
    completions = getattr(r_output, "outputs", None)
    for completion in () if completions is None else completions:
        if getattr(completion, "multimodal_output", None):
            completion.multimodal_output = {}

    return r_output
0 commit comments