Skip to content

Commit 301ceb2

Browse files
committed
feat: use async task polling for image generation
1 parent 39aa22c commit 301ceb2

File tree

4 files changed

+112
-32
lines changed

4 files changed

+112
-32
lines changed

src/modelscope_mcp_server/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@
1111
# Default timeout for requests
1212
DEFAULT_API_TIMEOUT_SECONDS = 5
1313
DEFAULT_IMAGE_GENERATION_TIMEOUT_SECONDS = 300
14+
15+
# Task polling interval (seconds)
16+
DEFAULT_TASK_POLL_INTERVAL_SECONDS = 5

src/modelscope_mcp_server/settings.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
DEFAULT_IMAGE_TO_IMAGE_MODEL,
1010
DEFAULT_MODELSCOPE_API_INFERENCE_DOMAIN,
1111
DEFAULT_MODELSCOPE_DOMAIN,
12+
DEFAULT_TASK_POLL_INTERVAL_SECONDS,
1213
DEFAULT_TEXT_TO_IMAGE_MODEL,
1314
)
1415

@@ -57,6 +58,12 @@ class Settings(BaseSettings):
5758
description="Default timeout for image generation requests",
5859
)
5960

61+
# Task polling
62+
task_poll_interval_seconds: int = Field(
63+
default=DEFAULT_TASK_POLL_INTERVAL_SECONDS,
64+
description="Polling interval in seconds when waiting for async tasks",
65+
)
66+
6067
# Logging settings
6168
log_level: str = Field(default="INFO", description="Logging level")
6269

src/modelscope_mcp_server/tools/aigc.py

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
Provides MCP tools for text-to-image generation, etc.
44
"""
55

6+
import time
67
from typing import Annotated
78

89
from fastmcp import FastMCP
@@ -77,7 +78,8 @@ async def generate_image(
7778
if not settings.is_api_token_configured():
7879
raise ValueError("API token is not set")
7980

80-
url = f"{settings.api_inference_domain}/v1/images/generations"
81+
# Step 1: submit async generation task
82+
submit_url = f"{settings.api_inference_domain}/v1/images/generations"
8183

8284
payload = {
8385
"model": model,
@@ -87,21 +89,44 @@ async def generate_image(
8789
if generation_type == GenerationType.IMAGE_TO_IMAGE and image_url:
8890
payload["image_url"] = image_url
8991

90-
response = default_client.post(
91-
url, json_data=payload, timeout=settings.default_image_generation_timeout_seconds
92+
submit_response = default_client.post(
93+
submit_url,
94+
json_data=payload,
95+
timeout=settings.default_image_generation_timeout_seconds,
96+
headers={"X-ModelScope-Async-Mode": "true"},
9297
)
9398

94-
images_data = response.get("images", [])
95-
96-
if len(images_data) == 0:
97-
raise Exception(f"No images found in response: {response}")
98-
99-
generated_image_url = images_data[0].get("url", "")
100-
if len(generated_image_url) == 0:
101-
raise Exception(f"No image URL found in response: {response}")
99+
task_id = submit_response.get("task_id")
100+
if not task_id:
101+
raise Exception(f"No task_id found in response: {submit_response}")
102+
103+
# Step 2: poll task result until succeed/failed or timeout
104+
start_time = time.time()
105+
task_url = f"{settings.api_inference_domain}/v1/tasks/{task_id}"
106+
while True:
107+
# timeout check
108+
if time.time() - start_time > settings.default_image_generation_timeout_seconds:
109+
raise TimeoutError("Image generation timed out - please try again later")
110+
111+
task_result = default_client.get(
112+
task_url,
113+
timeout=settings.default_api_timeout_seconds,
114+
headers={"X-ModelScope-Task-Type": "image_generation"},
115+
)
102116

103-
return ImageGenerationResult(
104-
type=generation_type,
105-
model=model,
106-
image_url=generated_image_url,
107-
)
117+
status = task_result.get("task_status")
118+
if status == "SUCCEED":
119+
output_images = task_result.get("output_images") or []
120+
if not output_images:
121+
raise Exception(f"No output images found in task result: {task_result}")
122+
generated_image_url = output_images[0]
123+
return ImageGenerationResult(
124+
type=generation_type,
125+
model=model,
126+
image_url=generated_image_url,
127+
)
128+
if status == "FAILED":
129+
raise Exception("Image Generation Failed.")
130+
131+
logger.info(f"Image generation task {task_id} is {status}, waiting for next poll...")
132+
time.sleep(settings.task_poll_interval_seconds)

tests/tools/test_generate_image.py

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,18 @@
77

88

99
async def test_text_to_image_generation_success(mcp_server, mocker):
10-
"""Test successful text-to-image generation."""
11-
mock_response_data = {"images": [{"url": "https://example.com/generated_image.jpg"}]}
12-
mock_post = mocker.patch("modelscope_mcp_server.client.default_client.post", return_value=mock_response_data)
10+
"""Test successful text-to-image generation with async polling."""
11+
mock_post = mocker.patch(
12+
"modelscope_mcp_server.client.default_client.post",
13+
return_value={"task_id": "task-text-1"},
14+
)
15+
mock_get = mocker.patch(
16+
"modelscope_mcp_server.client.default_client.get",
17+
return_value={
18+
"task_status": "SUCCEED",
19+
"output_images": ["https://example.com/generated_image.jpg"],
20+
},
21+
)
1322

1423
async with Client(mcp_server) as client:
1524
result = await client.call_tool(
@@ -32,12 +41,22 @@ async def test_text_to_image_generation_success(mcp_server, mocker):
3241
)
3342

3443
mock_post.assert_called_once()
44+
mock_get.assert_called()
3545

3646

3747
async def test_image_to_image_generation_success(mcp_server, mocker):
38-
"""Test successful image-to-image generation."""
39-
mock_response_data = {"images": [{"url": "https://example.com/modified_image.jpg"}]}
40-
mock_post = mocker.patch("modelscope_mcp_server.client.default_client.post", return_value=mock_response_data)
48+
"""Test successful image-to-image generation with async polling."""
49+
mock_post = mocker.patch(
50+
"modelscope_mcp_server.client.default_client.post",
51+
return_value={"task_id": "task-image-1"},
52+
)
53+
mock_get = mocker.patch(
54+
"modelscope_mcp_server.client.default_client.get",
55+
return_value={
56+
"task_status": "SUCCEED",
57+
"output_images": ["https://example.com/modified_image.jpg"],
58+
},
59+
)
4160

4261
async with Client(mcp_server) as client:
4362
result = await client.call_tool(
@@ -61,12 +80,22 @@ async def test_image_to_image_generation_success(mcp_server, mocker):
6180
)
6281

6382
mock_post.assert_called_once()
83+
mock_get.assert_called()
6484

6585

6686
async def test_generate_image_with_default_model(mcp_server, mocker):
67-
"""Test image generation with default model when no model is specified."""
68-
mock_response_data = {"images": [{"url": "https://example.com/default_model_image.jpg"}]}
69-
mocker.patch("modelscope_mcp_server.client.default_client.post", return_value=mock_response_data)
87+
"""Test image generation with default model when no model is specified (async)."""
88+
mocker.patch(
89+
"modelscope_mcp_server.client.default_client.post",
90+
return_value={"task_id": "task-default-1"},
91+
)
92+
mocker.patch(
93+
"modelscope_mcp_server.client.default_client.get",
94+
return_value={
95+
"task_status": "SUCCEED",
96+
"output_images": ["https://example.com/default_model_image.jpg"],
97+
},
98+
)
7099

71100
async with Client(mcp_server) as client:
72101
result = await client.call_tool(
@@ -141,10 +170,10 @@ async def test_generate_image_timeout_error(mcp_server, mocker):
141170

142171

143172
async def test_generate_image_malformed_response(mcp_server, mocker):
144-
"""Test handling of malformed API response."""
173+
"""Test handling of malformed API response on submit (missing task_id)."""
145174
malformed_response_data = {
146175
"result": "success",
147-
# Missing 'images' field
176+
# Missing 'task_id' field
148177
}
149178
mocker.patch("modelscope_mcp_server.client.default_client.post", return_value=malformed_response_data)
150179

@@ -159,13 +188,22 @@ async def test_generate_image_malformed_response(mcp_server, mocker):
159188
)
160189

161190
print(f"✅ Malformed response error handled correctly: {exc_info.value}")
162-
assert "No images found in response" in str(exc_info.value)
191+
assert "No task_id found in response" in str(exc_info.value)
163192

164193

165194
async def test_generate_image_request_parameters(mcp_server, mocker):
166-
"""Test that the correct parameters are sent in the request."""
167-
mock_response_data = {"images": [{"url": "https://example.com/test_image.jpg"}]}
168-
mock_post = mocker.patch("modelscope_mcp_server.client.default_client.post", return_value=mock_response_data)
195+
"""Test that the correct parameters are sent in the request (async)."""
196+
mock_post = mocker.patch(
197+
"modelscope_mcp_server.client.default_client.post",
198+
return_value={"task_id": "task-param-1"},
199+
)
200+
mock_get = mocker.patch(
201+
"modelscope_mcp_server.client.default_client.get",
202+
return_value={
203+
"task_status": "SUCCEED",
204+
"output_images": ["https://example.com/test_image.jpg"],
205+
},
206+
)
169207

170208
async with Client(mcp_server) as client:
171209
await client.call_tool(
@@ -177,7 +215,7 @@ async def test_generate_image_request_parameters(mcp_server, mocker):
177215
},
178216
)
179217

180-
# Verify the request was called with correct parameters
218+
# Verify the submit request was called with correct parameters
181219
mock_post.assert_called_once()
182220
call_args = mock_post.call_args
183221

@@ -194,4 +232,11 @@ async def test_generate_image_request_parameters(mcp_server, mocker):
194232
# Check timeout
195233
assert call_args.kwargs["timeout"] == 300
196234

235+
# Check headers include async mode
236+
headers = call_args.kwargs.get("headers", {})
237+
assert headers.get("X-ModelScope-Async-Mode") == "true"
238+
239+
# Verify polling was performed
240+
mock_get.assert_called()
241+
197242
print("✅ Request parameters verified correctly")

0 commit comments

Comments
 (0)