|
1 | 1 | import asyncio |
| 2 | +import aiohttp |
2 | 3 | from pathlib import Path |
3 | 4 | from timeit import default_timer as timer |
4 | 5 | import pytest |
@@ -57,16 +58,20 @@ async def runner(): |
57 | 58 | return results[0] |
58 | 59 |
|
59 | 60 |
|
60 | | -def create_simple_workflow(prompt="fluffy ball", input: Image | None = None): |
| 61 | +def create_simple_workflow(prompt="fluffy ball", input: Image | Extent | None = None): |
61 | 62 | start = 0 |
62 | | - images = ImageInput.from_extent(Extent(512, 512)) |
63 | | - if input: |
| 63 | + if isinstance(input, Image): |
| 64 | + images = ImageInput.from_extent(input.extent) |
64 | 65 | images.initial_image = input |
65 | 66 | images.hires_image = input |
66 | 67 | start = 4 |
| 68 | + elif isinstance(input, Extent): |
| 69 | + images = ImageInput.from_extent(input) |
| 70 | + else: |
| 71 | + images = ImageInput.from_extent(Extent(512, 512)) |
67 | 72 |
|
68 | 73 | return WorkflowInput( |
69 | | - WorkflowKind.generate if input is None else WorkflowKind.refine, |
| 74 | + WorkflowKind.generate if images.initial_image is None else WorkflowKind.refine, |
70 | 75 | images=images, |
71 | 76 | models=CheckpointInput("dreamshaper_8.safetensors"), |
72 | 77 | sampling=SamplingInput( |
@@ -230,3 +235,38 @@ async def main(): |
230 | 235 | end_time = timer() |
231 | 236 | duration = end_time - start_time |
232 | 237 | print(f"Completed 5 x 2 jobs in {duration:.2f} seconds", end=" ") |
| 238 | + |
| 239 | + |
| 240 | +def test_error_workflow(qtapp, cloud_client: CloudClient): |
| 241 | + workflow = create_simple_workflow() |
| 242 | + workflow.kind = WorkflowKind.refine # Error: refine requires an input image |
| 243 | + with pytest.raises(Exception, match="failed"): |
| 244 | + run_and_save(qtapp, cloud_client, workflow, "error_workflow") |
| 245 | + |
| 246 | + |
| 247 | +def test_job_timeout(pytestconfig, qtapp, cloud_service: CloudService): |
| 248 | + if not cloud_service.enabled: |
| 249 | + pytest.skip("Cloud service not running") |
| 250 | + |
| 251 | + async def main(): |
| 252 | + user = await cloud_service.create_user("timeout-tester") |
| 253 | + client = await CloudClient.connect(cloud_service.url, user["token"]) |
| 254 | + big_workflow = create_simple_workflow(input=Extent(2048, 1536)) |
| 255 | + |
| 256 | + await cloud_service.set_worker_job_timeout(5) |
| 257 | + |
| 258 | + with pytest.raises(Exception, match="timeout"): |
| 259 | + await receive_images(client, big_workflow) |
| 260 | + |
| 261 | + # Worker should be restarted and accept new jobs |
| 262 | + for _attempt in range(5): |
| 263 | + try: |
| 264 | + await cloud_service.set_worker_job_timeout(480) |
| 265 | + break |
| 266 | + except aiohttp.ClientConnectionError: |
| 267 | + await asyncio.sleep(2) # Wait for worker to be back up |
| 268 | + small_workflow = create_simple_workflow() |
| 269 | + images = await receive_images(client, small_workflow) |
| 270 | + assert len(images) == 2 |
| 271 | + |
| 272 | + qtapp.run(main()) |
0 commit comments