Skip to content

Commit ca1d14e

Browse files
committed
[Bugfix] Enforce --max-generated-image-size on /v1/images/generations (vllm-project#2599)
Signed-off-by: Nick Cao <ncao@redhat.com> Co-authored-by: Claude <noreply@anthropic.com> (cherry picked from commit 3bd8a52) Signed-off-by: David Chen <530634352@qq.com>
1 parent 73d31ed commit ca1d14e

File tree

2 files changed

+49
-21
lines changed

2 files changed

+49
-21
lines changed

tests/entrypoints/openai_api/test_image_server.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def test_client(mock_async_diffusion):
172172
app.state.diffusion_model_name = "Qwen/Qwen-Image" # For models endpoint
173173
app.state.args = Namespace(
174174
default_sampling_params='{"0": {"num_inference_steps":4, "guidance_scale":7.5}}',
175-
max_generated_image_size=4096, # 64*64
175+
max_generated_image_size=1024 * 1792,
176176
)
177177

178178
return TestClient(app)
@@ -239,7 +239,7 @@ def async_omni_stage_configs_only_client():
239239
# AsyncOmni exposes stage_configs on the engine instance.
240240
app.state.args = Namespace(
241241
default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}',
242-
max_generated_image_size=4096, # 64*64
242+
max_generated_image_size=1024 * 1792,
243243
)
244244
return TestClient(app)
245245

@@ -386,6 +386,18 @@ def test_image_edits_async_omni_stage_configs_only(async_omni_stage_configs_only
386386
assert len(captured) == 2
387387

388388

389+
def test_generate_images_max_size_rejected(async_omni_test_client):
390+
"""Test that a size exceeding max_generated_image_size returns 400."""
391+
response = async_omni_test_client.post(
392+
"/v1/images/generations",
393+
json={
394+
"prompt": "a cat",
395+
"size": "2048x2048", # 4,194,304 pixels > max_generated_image_size (1,048,576)
396+
},
397+
)
398+
assert response.status_code == 400
399+
400+
389401
def test_generate_multiple_images(test_client):
390402
"""Test generating multiple images"""
391403
response = test_client.post(
@@ -976,12 +988,13 @@ def test_image_edit_parameter_default_single_stage(test_client):
976988
assert captured_sampling_params.num_inference_steps == 4
977989
assert captured_sampling_params.guidance_scale == 7.5
978990

991+
# Size exceeding max_generated_image_size (1024*1792) returns 400
979992
response = test_client.post(
980993
"/v1/images/edits",
981994
files=[("image", img_bytes_1)],
982995
data={
983996
"prompt": "hello world.",
984-
"size": "96x96",
997+
"size": "2048x2048",
985998
},
986999
)
9871000
assert response.status_code == 400

vllm_omni/entrypoints/openai/api_server.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,6 +1299,10 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
12991299
size_str = f"{width}x{height}"
13001300
else:
13011301
size_str = "model default"
1302+
1303+
app_state_args = getattr(raw_request.app.state, "args", None)
1304+
_check_max_generated_image_size(app_state_args, width, height)
1305+
13021306
_update_if_not_none(gen_params, "width", width)
13031307
_update_if_not_none(gen_params, "height", height)
13041308

@@ -1484,7 +1488,6 @@ async def edit_images(
14841488
)
14851489

14861490
# 3.3 Parse and add size if provided
1487-
max_generated_image_size = getattr(app_state_args, "max_generated_image_size", None)
14881491
width, height = None, None
14891492
if size.lower() == "auto":
14901493
if resolution is None:
@@ -1494,23 +1497,7 @@ async def edit_images(
14941497
else:
14951498
width, height = parse_size(size)
14961499

1497-
# Check max_generated_image_size
1498-
if max_generated_image_size is not None:
1499-
if width is not None and height is not None:
1500-
if width * height > max_generated_image_size:
1501-
raise HTTPException(
1502-
status_code=HTTPStatus.BAD_REQUEST.value,
1503-
detail=f"Requested image size {width}x{height} exceeds the maximum allowed "
1504-
f"size of {max_generated_image_size} pixels.",
1505-
)
1506-
elif resolution is not None:
1507-
# When resolution is set, the output size is resolution * resolution
1508-
if resolution * resolution > max_generated_image_size:
1509-
raise HTTPException(
1510-
status_code=HTTPStatus.BAD_REQUEST.value,
1511-
detail=f"Requested resolution {resolution} (max {resolution}x{resolution} pixels) "
1512-
f"exceeds the maximum allowed size of {max_generated_image_size} pixels.",
1513-
)
1500+
_check_max_generated_image_size(app_state_args, width, height, resolution)
15141501

15151502
size_str = f"{width}x{height}" if width is not None and height is not None else "auto"
15161503
_update_if_not_none(gen_params, "width", width)
@@ -1709,6 +1696,34 @@ async def _generate_with_async_omni(
17091696
return result
17101697

17111698

1699+
def _check_max_generated_image_size(
1700+
app_state_args: Any,
1701+
width: int | None,
1702+
height: int | None,
1703+
resolution: int | None = None,
1704+
) -> None:
1705+
"""Raise 400 if the requested image size exceeds --max-generated-image-size."""
1706+
max_generated_image_size = getattr(app_state_args, "max_generated_image_size", None)
1707+
# Check max_generated_image_size
1708+
if max_generated_image_size is None:
1709+
return
1710+
if width is not None and height is not None:
1711+
if width * height > max_generated_image_size:
1712+
raise HTTPException(
1713+
status_code=HTTPStatus.BAD_REQUEST.value,
1714+
detail=f"Requested image size {width}x{height} exceeds the maximum allowed "
1715+
f"size of {max_generated_image_size} pixels.",
1716+
)
1717+
elif resolution is not None:
1718+
# When resolution is set, the output size is resolution * resolution
1719+
if resolution * resolution > max_generated_image_size:
1720+
raise HTTPException(
1721+
status_code=HTTPStatus.BAD_REQUEST.value,
1722+
detail=f"Requested resolution {resolution} (max {resolution}x{resolution} pixels) "
1723+
f"exceeds the maximum allowed size of {max_generated_image_size} pixels.",
1724+
)
1725+
1726+
17121727
def _update_if_not_none(object: Any, key: str, val: Any) -> None:
17131728
if val is not None:
17141729
setattr(object, key, val)

0 commit comments

Comments
 (0)