Skip to content

Commit abcbfc5

Browse files
committed
Enable CUDA in managed server tests if available
1 parent 0c2c768 commit abcbfc5

File tree

2 files changed

+36
-14
lines changed

2 files changed

+36
-14
lines changed

tests/test_client.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ai_diffusion import eventloop, resources
66
from ai_diffusion.api import WorkflowInput, WorkflowKind, LoraInput
77
from ai_diffusion.api import CheckpointInput, ImageInput, SamplingInput, ConditioningInput
8+
from ai_diffusion.platform_tools import get_cuda_devices
89
from ai_diffusion.resources import ControlMode
910
from ai_diffusion.network import NetworkError
1011
from ai_diffusion.image import Extent
@@ -19,12 +20,17 @@
1920

2021
@pytest.fixture(scope="session")
2122
def comfy_server(qtapp):
22-
server = Server(str(server_dir), ServerBackend.cpu)
23+
backend = ServerBackend.cpu
24+
if len(get_cuda_devices()) > 0:
25+
backend = ServerBackend.cuda
26+
27+
server = Server(str(server_dir), backend)
2328
assert server.state is ServerState.stopped, (
2429
f"Expected server installation at {server_dir}. To create the default installation run"
2530
" `pytest tests/test_server.py --test-install`"
2631
)
27-
yield qtapp.run(server.start(port=8189))
32+
qtapp.run(server.start(port=8189))
33+
yield server
2834
qtapp.run(server.stop())
2935

3036

@@ -47,9 +53,10 @@ async def main():
4753

4854

4955
@pytest.mark.parametrize("cancel_point", ["after_enqueue", "after_start", "after_sampling"])
50-
def test_cancel(qtapp, comfy_server, cancel_point):
56+
def test_cancel(qtapp, comfy_server: Server, cancel_point):
5157
async def main():
52-
client = await ComfyClient.connect(comfy_server)
58+
assert comfy_server.url is not None
59+
client = await ComfyClient.connect(comfy_server.url)
5360
async for _ in client.discover_models(refresh=False):
5461
pass
5562
job_id = None
@@ -64,7 +71,7 @@ async def main():
6471
assert msg.event is not ClientEvent.finished
6572
assert msg.job_id == job_id or msg.job_id == ""
6673
if not job_id:
67-
job_id = await client.enqueue(make_default_work(steps=200))
74+
job_id = await client.enqueue(make_default_work(steps=1000))
6875
assert client.queued_count == 1
6976
if not interrupted:
7077
if cancel_point == "after_enqueue":
@@ -99,13 +106,14 @@ async def main():
99106
qtapp.run(main())
100107

101108

102-
def test_disconnect(qtapp, comfy_server):
109+
def test_disconnect(qtapp, comfy_server: Server):
103110
async def listen(client: ComfyClient):
104111
async for msg in client.listen():
105112
assert msg.event is ClientEvent.connected
106113

107114
async def main():
108-
client = await ComfyClient.connect(comfy_server)
115+
assert comfy_server.url is not None
116+
client = await ComfyClient.connect(comfy_server.url)
109117
task = eventloop._loop.create_task(listen(client))
110118
task.cancel()
111119
with pytest.raises(asyncio.CancelledError):
@@ -156,16 +164,23 @@ def check_resolve_sd_version(client: ComfyClient, arch: Arch):
156164
assert resolve_arch(style, None) == arch
157165

158166

159-
def test_info(pytestconfig, qtapp, comfy_server):
167+
def check_nunchaku(server: Server, client: ComfyClient):
168+
if server.backend is ServerBackend.cuda:
169+
assert "NunchakuFluxDiTLoader" in client.models.node_inputs.nodes
170+
171+
172+
def test_info(pytestconfig, qtapp, comfy_server: Server):
160173
async def main():
161-
client = await ComfyClient.connect(comfy_server)
174+
assert comfy_server.url is not None
175+
client = await ComfyClient.connect(comfy_server.url)
162176
async for _ in client.discover_models(refresh=False):
163177
pass
164178
check_client_info(client)
165179
await client.refresh()
166180
check_client_info(client)
167181
check_resolve_sd_version(client, Arch.sd15)
168182
# check_resolve_sd_version(client, Arch.sdxl) # no SDXL checkpoint in default installation
183+
check_nunchaku(comfy_server, client)
169184

170185
qtapp.run(main())
171186

tests/test_server.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import sys
88

99
from ai_diffusion import network, server, resources
10+
from ai_diffusion.platform_tools import get_cuda_devices
1011
from ai_diffusion.style import Arch
1112
from ai_diffusion.server import Server, ServerState, ServerBackend, InstallationProgress
1213
from ai_diffusion.server import model_dirs
@@ -48,6 +49,12 @@ def clear_test_server():
4849
server_dir.mkdir(exist_ok=True)
4950

5051

52+
def get_backend():
53+
if len(get_cuda_devices()) > 0:
54+
return ServerBackend.cuda
55+
return ServerBackend.cpu
56+
57+
5158
def test_install_and_run(qtapp, pytestconfig, local_download_server):
5259
"""Test installing and running ComfyUI server from scratch.
5360
* Takes a while, only runs with --test-install
@@ -62,7 +69,7 @@ def test_install_and_run(qtapp, pytestconfig, local_download_server):
6269

6370
clear_test_server()
6471

65-
server = Server(str(server_dir), ServerBackend.cpu)
72+
server = Server(str(server_dir), get_backend())
6673
assert server.state in [ServerState.not_installed, ServerState.missing_resources]
6774

6875
last_stage = ""
@@ -110,7 +117,7 @@ def test_run_external(qtapp, pytestconfig):
110117
if not (server_dir / "ComfyUI").exists():
111118
pytest.skip("ComfyUI installation not found")
112119

113-
server = Server(str(server_dir), ServerBackend.cpu)
120+
server = Server(str(server_dir), get_backend())
114121
assert server.has_python
115122
assert server.state in [ServerState.stopped, ServerState.missing_resources]
116123

@@ -129,7 +136,7 @@ def test_extra_model_dirs(pytestconfig):
129136
if not pytestconfig.getoption("--test-install"):
130137
pytest.skip("Only runs with --test-install")
131138

132-
server = Server(str(server_dir), ServerBackend.cpu)
139+
server = Server(str(server_dir), get_backend())
133140
# Requires test_install_and_run to setup the server
134141
assert server.state in [ServerState.stopped]
135142

@@ -141,7 +148,7 @@ def test_verify_and_fix(qtapp, pytestconfig, local_download_server):
141148
if not pytestconfig.getoption("--test-install"):
142149
pytest.skip("Only runs with --test-install")
143150

144-
server = Server(str(server_dir), ServerBackend.cpu)
151+
server = Server(str(server_dir), get_backend())
145152
# Requires test_install_and_run to setup the server
146153
assert server.state in [ServerState.stopped]
147154

@@ -179,7 +186,7 @@ def test_uninstall(qtapp, pytestconfig, local_download_server):
179186
temp_server_dir = server_dir.parent / "temp-server"
180187
if temp_server_dir.exists():
181188
shutil.rmtree(temp_server_dir, ignore_errors=True)
182-
server = Server(str(temp_server_dir), ServerBackend.cpu)
189+
server = Server(str(temp_server_dir), get_backend())
183190
assert server.state is ServerState.not_installed
184191

185192
def handle_progress(report: InstallationProgress):

0 commit comments

Comments
 (0)