Skip to content

Commit 323cd83

Browse files
committed
publish_event is async
1 parent b99405f commit 323cd83

File tree

1 file changed

+10
-15
lines changed

1 file changed

+10
-15
lines changed

services/dask-sidecar/tests/unit/test_utils_dask.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
]
3939

4040

41-
def test_publish_event(
41+
async def test_publish_event(
4242
dask_client: distributed.Client, job_id: str, task_owner: TaskOwner
4343
):
4444
event_to_publish = TaskProgressEvent(
@@ -56,7 +56,7 @@ def handler(event: tuple) -> None:
5656

5757
dask_client.subscribe_topic(TaskProgressEvent.topic_name(), handler)
5858

59-
publish_event(dask_client, event=event_to_publish)
59+
await publish_event(dask_client, event=event_to_publish)
6060
for attempt in Retrying(
6161
wait=wait_fixed(0.2), stop=stop_after_delay(15), reraise=True
6262
):
@@ -127,33 +127,28 @@ async def test_publish_event_async_using_task(
127127
job_id: str,
128128
task_owner: TaskOwner,
129129
):
130-
dask_pub = distributed.Pub("some_topic", client=async_dask_client)
131-
dask_sub = distributed.Sub("some_topic", client=async_dask_client)
132130
NUMBER_OF_MESSAGES = 1000
133131
received_messages = []
134132

135-
async def _dask_sub_consumer_task(sub: distributed.Sub) -> None:
136-
print("--> starting consumer task")
137-
async for dask_event in sub:
138-
print(f"received {dask_event}")
139-
received_messages.append(dask_event)
140-
print("<-- finished consumer task")
133+
async def _consumer(event: tuple) -> None:
134+
print("received event", event)
135+
assert isinstance(event, tuple)
136+
received_messages.append(event)
141137

142-
consumer_task = asyncio_task(_dask_sub_consumer_task(dask_sub))
143-
assert consumer_task
138+
async_dask_client.subscribe_topic(TaskProgressEvent.topic_name(), _consumer)
144139

145-
async def _dask_publisher_task(pub: distributed.Pub) -> None:
140+
async def _dask_publisher_task(async_dask_client: distributed.Client) -> None:
146141
print("--> starting publisher task")
147142
for _ in range(NUMBER_OF_MESSAGES):
148143
event_to_publish = TaskProgressEvent(
149144
job_id=job_id,
150145
progress=0.5,
151146
task_owner=task_owner,
152147
)
153-
publish_event(dask_pub=pub, event=event_to_publish)
148+
await publish_event(async_dask_client, event=event_to_publish)
154149
print("<-- finished publisher task")
155150

156-
publisher_task = asyncio_task(_dask_publisher_task(dask_pub))
151+
publisher_task = asyncio_task(_dask_publisher_task(async_dask_client))
157152
assert publisher_task
158153

159154
async for attempt in AsyncRetrying(

0 commit comments

Comments
 (0)