Skip to content

Commit f9df268

Browse files
committed
Tests: timeout and error cases when using cloud service
1 parent f609afc commit f9df268

File tree

2 files changed

+61
-8
lines changed

2 files changed

+61
-8
lines changed

tests/conftest.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ def __init__(self, loop: QtTestApp, enabled=True):
115115
self.worker_proc: asyncio.subprocess.Process | None = None
116116
self.worker_task: asyncio.Task | None = None
117117
self.worker_log = None
118+
self.worker_url = ""
119+
self.worker_secret = ""
118120
self.enabled = has_local_cloud and enabled
119121

120122
async def serve(self, process: asyncio.subprocess.Process, log_file):
@@ -158,11 +160,11 @@ async def launch_worker(self):
158160
config = self.dir / "pod" / "_var" / "worker.json"
159161
assert config.exists(), "Worker config not found"
160162
config_dict = json.loads(config.read_text(encoding="utf-8"))
161-
worker_url = config_dict["public_url"]
162-
admin_secret = config_dict["admin_secret"]
163+
self.worker_url = config_dict["public_url"]
164+
self.worker_secret = config_dict["admin_secret"]
163165
self.worker_log = open(self.log_dir / "worker.log", "w", encoding="utf-8")
164-
if await self.check(f"{worker_url}/health", token=admin_secret):
165-
print(f"Worker running in external process at {worker_url}", file=self.worker_log)
166+
if await self.check(f"{self.worker_url}/health", token=self.worker_secret):
167+
print(f"Worker running in external process at {self.worker_url}", file=self.worker_log)
166168
return
167169

168170
workerpy = str(self.dir / "pod" / "worker.py")
@@ -220,6 +222,17 @@ async def create_user(self, username: str) -> dict[str, Any]:
220222
raise Exception(result["error"])
221223
return result
222224

225+
async def set_worker_job_timeout(self, timeout: int):
226+
headers = {
227+
"Authorization": f"Bearer {self.worker_secret}",
228+
"Content-Type": "application/json",
229+
}
230+
async with aiohttp.ClientSession(headers=headers) as session:
231+
async with session.post(
232+
f"{self.worker_url}/configure", json={"job_timeout": timeout}
233+
) as response:
234+
response.raise_for_status()
235+
223236
def __enter__(self):
224237
self.loop.run(self.start())
225238
return self

tests/test_service.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import aiohttp
23
from pathlib import Path
34
from timeit import default_timer as timer
45
import pytest
@@ -57,16 +58,20 @@ async def runner():
5758
return results[0]
5859

5960

60-
def create_simple_workflow(prompt="fluffy ball", input: Image | None = None):
61+
def create_simple_workflow(prompt="fluffy ball", input: Image | Extent | None = None):
6162
start = 0
62-
images = ImageInput.from_extent(Extent(512, 512))
63-
if input:
63+
if isinstance(input, Image):
64+
images = ImageInput.from_extent(input.extent)
6465
images.initial_image = input
6566
images.hires_image = input
6667
start = 4
68+
elif isinstance(input, Extent):
69+
images = ImageInput.from_extent(input)
70+
else:
71+
images = ImageInput.from_extent(Extent(512, 512))
6772

6873
return WorkflowInput(
69-
WorkflowKind.generate if input is None else WorkflowKind.refine,
74+
WorkflowKind.generate if images.initial_image is None else WorkflowKind.refine,
7075
images=images,
7176
models=CheckpointInput("dreamshaper_8.safetensors"),
7277
sampling=SamplingInput(
@@ -230,3 +235,38 @@ async def main():
230235
end_time = timer()
231236
duration = end_time - start_time
232237
print(f"Completed 5 x 2 jobs in {duration:.2f} seconds", end=" ")
238+
239+
240+
def test_error_workflow(qtapp, cloud_client: CloudClient):
241+
workflow = create_simple_workflow()
242+
workflow.kind = WorkflowKind.refine # Error: refine requires an input image
243+
with pytest.raises(Exception, match="failed"):
244+
run_and_save(qtapp, cloud_client, workflow, "error_workflow")
245+
246+
247+
def test_job_timeout(pytestconfig, qtapp, cloud_service: CloudService):
248+
if not cloud_service.enabled:
249+
pytest.skip("Cloud service not running")
250+
251+
async def main():
252+
user = await cloud_service.create_user("timeout-tester")
253+
client = await CloudClient.connect(cloud_service.url, user["token"])
254+
big_workflow = create_simple_workflow(input=Extent(2048, 1536))
255+
256+
await cloud_service.set_worker_job_timeout(5)
257+
258+
with pytest.raises(Exception, match="timeout"):
259+
await receive_images(client, big_workflow)
260+
261+
# Worker should be restarted and accept new jobs
262+
for _attempt in range(5):
263+
try:
264+
await cloud_service.set_worker_job_timeout(480)
265+
break
266+
except aiohttp.ClientConnectionError:
267+
await asyncio.sleep(2) # Wait for worker to be back up
268+
small_workflow = create_simple_workflow()
269+
images = await receive_images(client, small_workflow)
270+
assert len(images) == 2
271+
272+
qtapp.run(main())

0 commit comments

Comments
 (0)