2020from vllm_omni .distributed .omni_connectors .adapter import compute_talker_prompt_ids_length , try_send_via_connector
2121from vllm_omni .distributed .ray_utils .utils import try_close_ray
2222from vllm_omni .engine .input_processor import OmniInputProcessor
23+ from vllm_omni .entrypoints .cfg_companion_tracker import CfgCompanionTracker
2324from vllm_omni .entrypoints .client_request_state import ClientRequestState
2425from vllm_omni .entrypoints .omni import OmniBase
2526from vllm_omni .entrypoints .omni_stage import OmniStage
2829from vllm_omni .entrypoints .utils import (
2930 get_final_stage_id_for_e2e ,
3031)
31- from vllm_omni .inputs .data import OmniPromptType , OmniSamplingParams
32+ from vllm_omni .inputs .data import OmniDiffusionSamplingParams , OmniPromptType , OmniSamplingParams
3233
3334# Internal imports (our code)
3435from vllm_omni .lora .request import LoRARequest
@@ -125,6 +126,9 @@ def __init__(self, model: str, **kwargs: dict[str, Any]) -> None:
125126 # Used to avoid race condition between output_handler and collective_rpc
126127 self ._rpc_results : dict [int , dict [str , dict [str , Any ]]] = {}
127128
129+ # CFG companion → parent request ID mapping for output routing
130+ self ._companion_to_parent : dict [str , str ] = {}
131+
128132 super ().__init__ (model , ** kwargs )
129133
130134 # Register weak reference cleanup (called on garbage collection)
@@ -389,13 +393,38 @@ async def generate(
389393 req_state = ClientRequestState (request_id )
390394 req_state .metrics = metrics
391395 self .request_states [request_id ] = req_state
396+
397+ # Ensure modalities is in the prompt dict for CFG expansion
398+ # (offline path includes it; online serving passes it separately)
399+ if isinstance (prompt , dict ) and output_modalities and "modalities" not in prompt :
400+ prompt ["modalities" ] = output_modalities
401+
402+ # CFG companion tracking (prompt expansion + lifecycle management)
403+ cfg = CfgCompanionTracker (
404+ prompt_expand_func = getattr (self .stage_list [0 ], "prompt_expand_func" , None ),
405+ stage0_sampling_params = sampling_params_list [0 ],
406+ )
407+ expanded_companions = cfg .expand_prompts ({request_id : prompt })
408+
392409 sp0 : SamplingParams = sampling_params_list [0 ] # type: ignore[index]
393410 task = {
394411 "request_id" : request_id ,
395412 "engine_inputs" : prompt ,
396413 "sampling_params" : sp0 ,
397414 }
398415 self .stage_list [0 ].submit (task )
416+
417+ # Submit CFG companion requests to stage-0
418+ if cfg .is_active :
419+ for companion_id , companion_prompt in expanded_companions :
420+ self ._companion_to_parent [companion_id ] = request_id
421+ companion_task = {
422+ "request_id" : companion_id ,
423+ "engine_inputs" : companion_prompt ,
424+ "sampling_params" : cfg .stage0_sampling_params ,
425+ }
426+ self .stage_list [0 ].submit (companion_task )
427+
399428 metrics .stage_first_ts [0 ] = metrics .stage_first_ts [0 ] or time .time ()
400429 _req_start_ts [request_id ] = time .time ()
401430 logger .info (
@@ -421,6 +450,7 @@ async def generate(
421450 final_stage_id_for_e2e ,
422451 sampling_params_list ,
423452 prompt ,
453+ cfg = cfg ,
424454 ):
425455 yield output
426456
@@ -440,6 +470,9 @@ async def generate(
440470 logger .exception (f"[{ self ._name } ] Request { request_id } Failed to finalized/build/log summary: { e } " )
441471 finally :
442472 self .request_states .pop (request_id , None )
473+ if cfg .is_active :
474+ for cid in cfg .get_companion_request_ids (request_id ).values ():
475+ self ._companion_to_parent .pop (cid , None )
443476 except (asyncio .CancelledError , GeneratorExit ):
444477 await self .abort (request_id )
445478 logger .info ("[AsyncOrchestrator] Request %s aborted." , request_id )
@@ -603,12 +636,29 @@ async def _process_sequential_results(
603636 final_stage_id_for_e2e : int ,
604637 sampling_params_list : list [SamplingParams ],
605638 prompt : Any ,
639+ cfg : CfgCompanionTracker | None = None ,
606640 ) -> AsyncGenerator [OmniRequestOutput , None ]:
607641 for stage_id , stage in enumerate (self .stage_list [: final_stage_id_for_e2e + 1 ]):
642+ cfg_stage0 = stage_id == 0 and cfg is not None and cfg .is_active
608643 finished = False
609- while not finished :
644+
645+ while True :
646+ if finished and (
647+ not cfg_stage0 or cfg .all_companions_done (request_id ) or cfg .is_parent_failed (request_id )
648+ ):
649+ break
650+
610651 result = await req_state .queue .get ()
611- assert stage_id == req_state .stage_id
652+
653+ if cfg is not None and cfg .is_companion (result .get ("request_id" , "" )):
654+ if cfg_stage0 :
655+ rid = result .get ("request_id" )
656+ if "error" in result :
657+ cfg .on_companion_error (rid )
658+ else :
659+ cfg .on_companion_completed (rid )
660+ continue
661+
612662 engine_outputs , finished , output_to_yield = self ._process_single_result (
613663 result ,
614664 stage ,
@@ -629,6 +679,16 @@ async def _process_sequential_results(
629679 next_inputs = next_stage .process_engine_inputs (self .stage_list , prompt )
630680 sp_next : SamplingParams = sampling_params_list [next_stage_id ]
631681
682+ if cfg is not None and cfg .is_active and not cfg .is_parent_failed (request_id ):
683+ if isinstance (sp_next , OmniDiffusionSamplingParams ):
684+ sp_next = copy .deepcopy (sp_next )
685+ sp_next .cfg_kv_request_ids = cfg .get_companion_request_ids (request_id )
686+ logger .info (
687+ "Attaching cfg_kv_request_ids=%s to request %s" ,
688+ sp_next .cfg_kv_request_ids ,
689+ request_id ,
690+ )
691+
632692 # Check if we have a connector for this edge
633693 connector_key = (str (stage_id ), str (next_stage_id ))
634694 connector = self .connectors .get (connector_key )
@@ -747,6 +807,7 @@ def _run_output_handler(self) -> None:
747807
748808 stage_list = self .stage_list
749809 request_states = self .request_states
810+ companion_to_parent = self ._companion_to_parent
750811
751812 async def output_handler ():
752813 try :
@@ -773,6 +834,10 @@ async def output_handler():
773834 continue
774835 req_id = result .get ("request_id" )
775836 req_state = request_states .get (req_id )
837+ if req_state is None :
838+ parent_id = companion_to_parent .get (req_id )
839+ if parent_id is not None :
840+ req_state = request_states .get (parent_id )
776841 if req_state is None :
777842 logger .debug (
778843 f"[{ self ._name } ] Request may have been aborted; \
0 commit comments