from ai_diffusion import eventloop, resources
from ai_diffusion.api import WorkflowInput, WorkflowKind, LoraInput
from ai_diffusion.api import CheckpointInput, ImageInput, SamplingInput, ConditioningInput
+from ai_diffusion.platform_tools import get_cuda_devices
from ai_diffusion.resources import ControlMode
from ai_diffusion.network import NetworkError
from ai_diffusion.image import Extent


@pytest.fixture(scope="session")
def comfy_server(qtapp):
-    server = Server(str(server_dir), ServerBackend.cpu)
+    backend = ServerBackend.cpu
+    if len(get_cuda_devices()) > 0:
+        backend = ServerBackend.cuda
+
+    server = Server(str(server_dir), backend)
    assert server.state is ServerState.stopped, (
        f"Expected server installation at {server_dir}. To create the default installation run"
        " `pytest tests/test_server.py --test-install`"
    )
-    yield qtapp.run(server.start(port=8189))
+    qtapp.run(server.start(port=8189))
+    yield server
    qtapp.run(server.stop())
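The fixture now prefers a CUDA backend whenever get_cuda_devices() reports at least one device, and it yields the Server object itself instead of the value returned by server.start(). As an aside (not part of the commit), the same selection logic can be read as a small helper; pick_backend is a hypothetical name:

def pick_backend() -> ServerBackend:
    # Mirrors the inline logic in the comfy_server fixture above:
    # prefer CUDA when a device is reported, otherwise fall back to CPU.
    if len(get_cuda_devices()) > 0:
        return ServerBackend.cuda
    return ServerBackend.cpu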
@@ -47,9 +53,10 @@ async def main():
@pytest.mark.parametrize("cancel_point", ["after_enqueue", "after_start", "after_sampling"])
-def test_cancel(qtapp, comfy_server, cancel_point):
+def test_cancel(qtapp, comfy_server: Server, cancel_point):
    async def main():
-        client = await ComfyClient.connect(comfy_server)
+        assert comfy_server.url is not None
+        client = await ComfyClient.connect(comfy_server.url)
        async for _ in client.discover_models(refresh=False):
            pass
        job_id = None
@@ -64,7 +71,7 @@ async def main():
            assert msg.event is not ClientEvent.finished
            assert msg.job_id == job_id or msg.job_id == ""
            if not job_id:
-                job_id = await client.enqueue(make_default_work(steps=200))
+                job_id = await client.enqueue(make_default_work(steps=1000))
                assert client.queued_count == 1
            if not interrupted:
                if cancel_point == "after_enqueue":
@@ -99,13 +106,14 @@ async def main():
    qtapp.run(main())


-def test_disconnect(qtapp, comfy_server):
+def test_disconnect(qtapp, comfy_server: Server):
    async def listen(client: ComfyClient):
        async for msg in client.listen():
            assert msg.event is ClientEvent.connected

    async def main():
-        client = await ComfyClient.connect(comfy_server)
+        assert comfy_server.url is not None
+        client = await ComfyClient.connect(comfy_server.url)
        task = eventloop._loop.create_task(listen(client))
        task.cancel()
        with pytest.raises(asyncio.CancelledError):
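Since the fixture yields the Server rather than a URL, every test below follows the same connect pattern: assert that the started server exposes a URL, then pass it to ComfyClient.connect. A minimal sketch of that pattern (illustrative only; test_example_pattern is not part of the commit):

def test_example_pattern(qtapp, comfy_server: Server):
    async def main():
        # Guard against a server that failed to start before connecting.
        assert comfy_server.url is not None
        client = await ComfyClient.connect(comfy_server.url)
        async for _ in client.discover_models(refresh=False):
            pass

    qtapp.run(main())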
@@ -156,16 +164,23 @@ def check_resolve_sd_version(client: ComfyClient, arch: Arch):
    assert resolve_arch(style, None) == arch


-def test_info(pytestconfig, qtapp, comfy_server):
+def check_nunchaku(server: Server, client: ComfyClient):
+    if server.backend is ServerBackend.cuda:
+        assert "NunchakuFluxDiTLoader" in client.models.node_inputs.nodes
+
+
+def test_info(pytestconfig, qtapp, comfy_server: Server):
    async def main():
-        client = await ComfyClient.connect(comfy_server)
+        assert comfy_server.url is not None
+        client = await ComfyClient.connect(comfy_server.url)
        async for _ in client.discover_models(refresh=False):
            pass
        check_client_info(client)
        await client.refresh()
        check_client_info(client)
        check_resolve_sd_version(client, Arch.sd15)
        # check_resolve_sd_version(client, Arch.sdxl)  # no SDXL checkpoint in default installation
+        check_nunchaku(comfy_server, client)

    qtapp.run(main())
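check_nunchaku only asserts the NunchakuFluxDiTLoader node on CUDA installations, treating client.models.node_inputs.nodes as the collection of node types the server reports. A small debugging sketch under the assumption that this collection is iterable (an assumption, not verified against the commit):

def dump_nunchaku_nodes(client: ComfyClient):
    # List Nunchaku-related node types reported by the connected server.
    # Assumes node_inputs.nodes iterates over node type names.
    for name in client.models.node_inputs.nodes:
        if "Nunchaku" in name:
            print(name)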