Skip to content

Commit e0b7105

Browse files
authored
[BugFix]: Fix bagel online inference bug (vllm-project#1804)
Signed-off-by: princepride <wangzhipeng628@gmail.com>
1 parent 26b3dde commit e0b7105

File tree

5 files changed

+115
-13
lines changed

5 files changed

+115
-13
lines changed

tests/e2e/offline_inference/test_zimage_parallelism.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717
from pathlib import Path
1818

1919
import numpy as np
20-
import pytest
20+
21+
# import pytest
2122
import torch
2223
from PIL import Image
2324
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
2425

25-
from tests.utils import DeviceMemoryMonitor, hardware_test
26+
# from tests.utils import DeviceMemoryMonitor, hardware_test
27+
from tests.utils import DeviceMemoryMonitor
2628
from vllm_omni import Omni
2729
from vllm_omni.diffusion.data import DiffusionParallelConfig
2830
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

vllm_omni/entrypoints/async_omni.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from vllm_omni.distributed.omni_connectors.adapter import compute_talker_prompt_ids_length, try_send_via_connector
2121
from vllm_omni.distributed.ray_utils.utils import try_close_ray
2222
from vllm_omni.engine.input_processor import OmniInputProcessor
23+
from vllm_omni.entrypoints.cfg_companion_tracker import CfgCompanionTracker
2324
from vllm_omni.entrypoints.client_request_state import ClientRequestState
2425
from vllm_omni.entrypoints.omni import OmniBase
2526
from vllm_omni.entrypoints.omni_stage import OmniStage
@@ -28,7 +29,7 @@
2829
from vllm_omni.entrypoints.utils import (
2930
get_final_stage_id_for_e2e,
3031
)
31-
from vllm_omni.inputs.data import OmniPromptType, OmniSamplingParams
32+
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams
3233

3334
# Internal imports (our code)
3435
from vllm_omni.lora.request import LoRARequest
@@ -125,6 +126,9 @@ def __init__(self, model: str, **kwargs: dict[str, Any]) -> None:
125126
# Used to avoid race condition between output_handler and collective_rpc
126127
self._rpc_results: dict[int, dict[str, dict[str, Any]]] = {}
127128

129+
# CFG companion → parent request ID mapping for output routing
130+
self._companion_to_parent: dict[str, str] = {}
131+
128132
super().__init__(model, **kwargs)
129133

130134
# Register weak reference cleanup (called on garbage collection)
@@ -389,13 +393,38 @@ async def generate(
389393
req_state = ClientRequestState(request_id)
390394
req_state.metrics = metrics
391395
self.request_states[request_id] = req_state
396+
397+
# Ensure modalities is in the prompt dict for CFG expansion
398+
# (offline path includes it; online serving passes it separately)
399+
if isinstance(prompt, dict) and output_modalities and "modalities" not in prompt:
400+
prompt["modalities"] = output_modalities
401+
402+
# CFG companion tracking (prompt expansion + lifecycle management)
403+
cfg = CfgCompanionTracker(
404+
prompt_expand_func=getattr(self.stage_list[0], "prompt_expand_func", None),
405+
stage0_sampling_params=sampling_params_list[0],
406+
)
407+
expanded_companions = cfg.expand_prompts({request_id: prompt})
408+
392409
sp0: SamplingParams = sampling_params_list[0] # type: ignore[index]
393410
task = {
394411
"request_id": request_id,
395412
"engine_inputs": prompt,
396413
"sampling_params": sp0,
397414
}
398415
self.stage_list[0].submit(task)
416+
417+
# Submit CFG companion requests to stage-0
418+
if cfg.is_active:
419+
for companion_id, companion_prompt in expanded_companions:
420+
self._companion_to_parent[companion_id] = request_id
421+
companion_task = {
422+
"request_id": companion_id,
423+
"engine_inputs": companion_prompt,
424+
"sampling_params": cfg.stage0_sampling_params,
425+
}
426+
self.stage_list[0].submit(companion_task)
427+
399428
metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time()
400429
_req_start_ts[request_id] = time.time()
401430
logger.info(
@@ -421,6 +450,7 @@ async def generate(
421450
final_stage_id_for_e2e,
422451
sampling_params_list,
423452
prompt,
453+
cfg=cfg,
424454
):
425455
yield output
426456

@@ -440,6 +470,9 @@ async def generate(
440470
logger.exception(f"[{self._name}] Request {request_id} Failed to finalized/build/log summary: {e}")
441471
finally:
442472
self.request_states.pop(request_id, None)
473+
if cfg.is_active:
474+
for cid in cfg.get_companion_request_ids(request_id).values():
475+
self._companion_to_parent.pop(cid, None)
443476
except (asyncio.CancelledError, GeneratorExit):
444477
await self.abort(request_id)
445478
logger.info("[AsyncOrchestrator] Request %s aborted.", request_id)
@@ -603,12 +636,29 @@ async def _process_sequential_results(
603636
final_stage_id_for_e2e: int,
604637
sampling_params_list: list[SamplingParams],
605638
prompt: Any,
639+
cfg: CfgCompanionTracker | None = None,
606640
) -> AsyncGenerator[OmniRequestOutput, None]:
607641
for stage_id, stage in enumerate(self.stage_list[: final_stage_id_for_e2e + 1]):
642+
cfg_stage0 = stage_id == 0 and cfg is not None and cfg.is_active
608643
finished = False
609-
while not finished:
644+
645+
while True:
646+
if finished and (
647+
not cfg_stage0 or cfg.all_companions_done(request_id) or cfg.is_parent_failed(request_id)
648+
):
649+
break
650+
610651
result = await req_state.queue.get()
611-
assert stage_id == req_state.stage_id
652+
653+
if cfg is not None and cfg.is_companion(result.get("request_id", "")):
654+
if cfg_stage0:
655+
rid = result.get("request_id")
656+
if "error" in result:
657+
cfg.on_companion_error(rid)
658+
else:
659+
cfg.on_companion_completed(rid)
660+
continue
661+
612662
engine_outputs, finished, output_to_yield = self._process_single_result(
613663
result,
614664
stage,
@@ -629,6 +679,16 @@ async def _process_sequential_results(
629679
next_inputs = next_stage.process_engine_inputs(self.stage_list, prompt)
630680
sp_next: SamplingParams = sampling_params_list[next_stage_id]
631681

682+
if cfg is not None and cfg.is_active and not cfg.is_parent_failed(request_id):
683+
if isinstance(sp_next, OmniDiffusionSamplingParams):
684+
sp_next = copy.deepcopy(sp_next)
685+
sp_next.cfg_kv_request_ids = cfg.get_companion_request_ids(request_id)
686+
logger.info(
687+
"Attaching cfg_kv_request_ids=%s to request %s",
688+
sp_next.cfg_kv_request_ids,
689+
request_id,
690+
)
691+
632692
# Check if we have a connector for this edge
633693
connector_key = (str(stage_id), str(next_stage_id))
634694
connector = self.connectors.get(connector_key)
@@ -747,6 +807,7 @@ def _run_output_handler(self) -> None:
747807

748808
stage_list = self.stage_list
749809
request_states = self.request_states
810+
companion_to_parent = self._companion_to_parent
750811

751812
async def output_handler():
752813
try:
@@ -773,6 +834,10 @@ async def output_handler():
773834
continue
774835
req_id = result.get("request_id")
775836
req_state = request_states.get(req_id)
837+
if req_state is None:
838+
parent_id = companion_to_parent.get(req_id)
839+
if parent_id is not None:
840+
req_state = request_states.get(parent_id)
776841
if req_state is None:
777842
logger.debug(
778843
f"[{self._name}] Request may have been aborted; \

vllm_omni/entrypoints/openai/serving_chat.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -311,17 +311,21 @@ async def create_chat_completion(
311311
negative_prompt = extra_body.get("negative_prompt")
312312

313313
engine_prompt_image: dict[str, Any] | None = None
314+
is_img2img = False
314315
if reference_images:
315316
# Best-effort decode first reference image for i2i.
316317
try:
317318
img_bytes = base64.b64decode(reference_images[0])
318319
img = Image.open(BytesIO(img_bytes))
319-
engine_prompt_image = {"image": img}
320+
engine_prompt_image = {"img2img": img}
321+
is_img2img = True
320322
except Exception:
321323
engine_prompt_image = None
322324

323325
# Override the prompts produced by chat-template preprocessing.
324326
tprompt: OmniTextPrompt = {"prompt": extracted_prompt}
327+
if is_img2img:
328+
tprompt["modalities"] = ["img2img"]
325329
if negative_prompt is not None:
326330
tprompt["negative_prompt"] = negative_prompt
327331
# GLM-Image's _call_hf_processor expects target_h/target_w in mm_processor_kwargs
@@ -490,10 +494,11 @@ async def _preprocess_chat(
490494
)
491495

492496
# Preserve a clean text prompt for downstream stages (e.g., GLM-Image diffusion).
493-
# For /v1/chat/completions, `request_prompt` is often the rendered chat template.
494-
# Diffusion models generally want the raw user caption instead.
495-
output_modalities = getattr(self.engine_client, "output_modalities", None)
496-
if output_modalities and ("image" in output_modalities):
497+
# For image generation, we want the raw user caption instead of a rendered template.
498+
# But for multimodal comprehension (img2text), we MUST keep the rendered prompt
499+
# containing image tokens.
500+
req_modalities = getattr(request, "modalities", [])
501+
if req_modalities and ("image" in req_modalities):
497502
messages_as_dicts: list[dict[str, Any]] = []
498503
for msg in messages:
499504
if hasattr(msg, "model_dump"):

vllm_omni/model_executor/models/bagel/bagel.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,15 @@ def _get_subparsers(self):
202202

203203

204204
class OmniBagelMultiModalProcessor(BaseMultiModalProcessor[OmniBagelProcessingInfo]):
205+
IMG2IMG_PLACEHOLDER = "<|fim_middle|>"
206+
207+
def _cached_apply_hf_processor(self, inputs, timing_ctx):
208+
# img2img: prompt text must be modified based on mm data presence,
209+
# so text and mm data cannot be tokenized separately — bypass cache.
210+
if inputs.mm_data_items.get_all_counts().get("img2img", 0) > 0:
211+
return self._apply_hf_processor(inputs, timing_ctx)
212+
return super()._cached_apply_hf_processor(inputs, timing_ctx)
213+
205214
def _get_mm_fields_config(self, hf_inputs, hf_processor_mm_kwargs):
206215
return {
207216
"pixel_values": MultiModalFieldConfig.batched("image"),
@@ -218,6 +227,9 @@ def _call_hf_processor(
218227
has_image = "images" in mm_data
219228
has_img2img = "pixel_values_img2img" in mm_data
220229

230+
if has_img2img and self.IMG2IMG_PLACEHOLDER not in prompt:
231+
prompt = f"{self.IMG2IMG_PLACEHOLDER}{prompt}"
232+
221233
if has_image and has_img2img:
222234
outputs = BatchFeature()
223235

vllm_omni/model_executor/models/glm_image/glm_image_ar.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,11 @@
7878
MultiModalFieldConfig,
7979
MultiModalKwargsItems,
8080
)
81-
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
81+
from vllm.multimodal.parse import (
82+
ImageProcessorItems,
83+
MultiModalDataItems,
84+
MultiModalDataParser,
85+
)
8286
from vllm.multimodal.processing import (
8387
BaseDummyInputsBuilder,
8488
BaseMultiModalProcessor,
@@ -115,6 +119,15 @@ class GlmImagePixelInputs(TensorSchema):
115119
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
116120

117121

122+
class GlmImageDataParser(MultiModalDataParser):
123+
"""GLM-Image treats ``img2img`` input identically to ``image``."""
124+
125+
def _get_subparsers(self):
126+
parsers = super()._get_subparsers()
127+
parsers["img2img"] = self._parse_image_data
128+
return parsers
129+
130+
118131
class GlmImageProcessingInfo(BaseProcessingInfo):
119132
"""
120133
Processing information for GLM-Image model.
@@ -162,14 +175,19 @@ def get_hf_processor(self, **kwargs: object):
162175
**kwargs,
163176
)
164177

178+
def get_data_parser(self) -> GlmImageDataParser:
179+
return GlmImageDataParser(
180+
expected_hidden_size=self._get_expected_hidden_size(),
181+
)
182+
165183
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
166184
# GLM-Image is an image GENERATION model that supports:
167185
# - Text-to-image (t2i): no multimodal input needed
168186
# - Image-to-image (i2i): source images provided as input
169187
#
170188
# For i2i mode, we support up to 1 image as condition.
171-
# The model architecture supports multiple images but typical usage is 1.
172-
return {"image": 1}
189+
# "img2img" is an alias used by the serving layer; parsed as "image".
190+
return {"image": 1, "img2img": 1}
173191

174192
def get_num_image_tokens(
175193
self,

0 commit comments

Comments
 (0)