|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | + |
| 4 | +""" |
| 5 | +End-to-end test for Bagel img2img generation. |
| 6 | +
|
| 7 | +This test validates that the Bagel model generates images from an input image |
| 8 | +and text prompt that match expected reference pixel values within a ±5 tolerance. |
| 9 | +
|
| 10 | +Equivalent to running: |
| 11 | + python3 examples/offline_inference/bagel/end2end.py \ |
| 12 | + --prompts "Change the grass color to red" \ |
| 13 | + --modality img2img --step 15 \ |
| 14 | + --image-path 2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg |
| 15 | +""" |
| 16 | + |
| 17 | +import socket |
| 18 | +from pathlib import Path |
| 19 | +from typing import Any |
| 20 | + |
| 21 | +import pytest |
| 22 | +from PIL import Image |
| 23 | +from vllm.assets.image import ImageAsset |
| 24 | + |
| 25 | +from tests.utils import hardware_test |
| 26 | +from vllm_omni.entrypoints.omni import Omni |
| 27 | + |
# Reference pixel data extracted from the known-good output image.
# Generated with seed=52, num_inference_steps=15,
# prompt='Change the grass color to red',
# input image: 2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg
# Each entry pairs an (x, y) coordinate with the expected (R, G, B) value
# at that position; sample points cover sky, horizon, and recolored grass.
REFERENCE_PIXELS = [
    {"position": (100, 100), "rgb": (157, 172, 217)},
    {"position": (400, 50), "rgb": (105, 144, 218)},
    {"position": (700, 100), "rgb": (118, 159, 233)},
    {"position": (150, 400), "rgb": (195, 34, 60)},
    {"position": (512, 336), "rgb": (222, 214, 193)},
    {"position": (700, 400), "rgb": (197, 15, 43)},
    {"position": (100, 600), "rgb": (105, 13, 18)},
    {"position": (400, 600), "rgb": (169, 33, 44)},
    {"position": (700, 600), "rgb": (101, 86, 93)},
    {"position": (256, 256), "rgb": (181, 202, 222)},
]

# Maximum allowed absolute difference per color channel when comparing a
# generated pixel against its reference value.
PIXEL_TOLERANCE = 5

# Prompt text wrapped in the model's chat/control tokens — presumably the
# Bagel tokenizer's expected img2img template; verify against the example
# script referenced in the module docstring.
DEFAULT_PROMPT = "<|fim_middle|><|im_start|>Change the grass color to red<|im_end|>"

# Expected (width, height) of the generated output image.
EXPECTED_OUTPUT_SIZE = (1024, 672)
| 50 | + |
| 51 | + |
def _load_input_image() -> Image.Image:
    """Load the boardwalk test asset and return it as an RGB PIL image."""
    asset = ImageAsset("2560px-Gfp-wisconsin-madison-the-nature-boardwalk")
    return asset.pil_image.convert("RGB")
| 55 | + |
| 56 | + |
| 57 | +def _find_free_port() -> int: |
| 58 | + """Find and return a free ephemeral port by binding to port 0.""" |
| 59 | + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: |
| 60 | + s.bind(("127.0.0.1", 0)) |
| 61 | + s.listen(1) |
| 62 | + port = s.getsockname()[1] |
| 63 | + return port |
| 64 | + |
| 65 | + |
def _configure_sampling_params(omni: Omni, max_tokens: int = 1, num_inference_steps: int = 15) -> list:
    """Configure sampling parameters for Bagel img2img generation.

    Args:
        omni: The Omni instance to get default params from.
        max_tokens: Maximum tokens for the first stage.
        num_inference_steps: Number of inference steps for the diffusion stage.

    Returns:
        Configured sampling params list.
    """
    params = omni.default_sampling_params_list
    first_stage = params[0]
    first_stage.max_tokens = max_tokens  # type: ignore
    # The diffusion stage is optional; configure it only when present.
    if len(params) >= 2:
        diffusion_stage = params[1]
        diffusion_stage.num_inference_steps = num_inference_steps  # type: ignore
        # Classifier-free guidance scales for text and image conditioning.
        diffusion_stage.extra_args = {  # type: ignore
            "cfg_text_scale": 4.0,
            "cfg_img_scale": 1.5,
        }
    return params
| 86 | + |
| 87 | + |
def _extract_generated_image(omni_outputs: list) -> Image.Image | None:
    """Extract the generated image from Omni outputs.

    Args:
        omni_outputs: List of outputs from omni.generate().

    Returns:
        The first generated PIL Image, or None if no image found.
    """
    for output in omni_outputs:
        # Prefer images attached directly to the request output.
        direct_images = getattr(output, "images", None)
        if direct_images:
            return direct_images[0]
        # Otherwise look through the nested per-stage outputs, if any.
        stage_outputs = getattr(output, "request_output", None)
        if stage_outputs:
            for stage in stage_outputs:
                stage_images = getattr(stage, "images", None)
                if stage_images:
                    return stage_images[0]
    return None
| 105 | + |
| 106 | + |
def _validate_pixels(
    image: Image.Image,
    reference_pixels: list[dict[str, Any]] = REFERENCE_PIXELS,
    tolerance: int = PIXEL_TOLERANCE,
) -> None:
    """Validate that image pixels match expected reference values.

    Args:
        image: The PIL Image to validate.
        reference_pixels: List of dicts with 'position' (x, y) and 'rgb' (R, G, B).
        tolerance: Maximum allowed difference per color channel.

    Raises:
        AssertionError: If any pixel differs beyond tolerance.
    """
    for entry in reference_pixels:
        x, y = entry["position"]
        expected_rgb = entry["rgb"]
        # Drop any alpha channel before comparing.
        actual_rgb = image.getpixel((x, y))[:3]
        deltas = [abs(got - want) for got, want in zip(actual_rgb, expected_rgb)]
        assert all(delta <= tolerance for delta in deltas), (
            f"Pixel mismatch at ({x}, {y}): expected {expected_rgb}, got {actual_rgb}"
        )
| 129 | + |
| 130 | + |
def _generate_bagel_img2img(
    omni: Omni,
    input_image: Image.Image,
    prompt: str = DEFAULT_PROMPT,
) -> Image.Image:
    """Generate an image using Bagel model with img2img pipeline.

    Args:
        omni: The Omni instance to use for generation.
        input_image: The input PIL Image for img2img.
        prompt: The text prompt for image editing.

    Returns:
        The generated PIL Image.

    Raises:
        AssertionError: If no image is generated or size is incorrect.
    """
    sampling_params = _configure_sampling_params(omni)

    request = {
        "prompt": prompt,
        "multi_modal_data": {"img2img": input_image},
        "modalities": ["img2img"],
    }
    outputs = list(
        omni.generate(
            prompts=[request],
            sampling_params_list=sampling_params,
        )
    )

    image = _extract_generated_image(outputs)
    assert image is not None, "No images generated"
    assert image.size == EXPECTED_OUTPUT_SIZE, f"Expected {EXPECTED_OUTPUT_SIZE}, got {image.size}"

    return image
| 169 | + |
| 170 | + |
@pytest.mark.core_model
@pytest.mark.diffusion
@hardware_test(res={"cuda": "H100"})
def test_bagel_img2img_shared_memory_connector():
    """Test Bagel img2img with shared memory connector."""
    source_image = _load_input_image()
    stage_config = Path(__file__).parent / "stage_configs" / "bagel_sharedmemory_ci.yaml"
    omni = Omni(
        model="ByteDance-Seed/BAGEL-7B-MoT",
        stage_configs_path=str(stage_config),
        stage_init_timeout=300,
    )

    # Always release engine resources, even if generation or validation fails.
    try:
        result_image = _generate_bagel_img2img(omni, source_image)
        _validate_pixels(result_image)
    finally:
        omni.close()
0 commit comments