@@ -12,7 +12,6 @@

import pytest
from celery import Celery, Task  # pylint: disable=no-name-in-module
-from celery.worker.worker import WorkController  # pylint: disable=no-name-in-module
from celery_library.errors import TaskNotFoundError, TransferrableCeleryError
from celery_library.task import register_task
from celery_library.task_manager import CeleryTaskManager
@@ -23,8 +22,12 @@
from servicelib.celery.models import (
    ExecutionMetadata,
    OwnerMetadata,
+    TaskDataEvent,
+    TaskEventType,
    TaskID,
    TaskState,
+    TaskStatusEvent,
+    TaskStatusValue,
    TaskUUID,
    Wildcard,
)
@@ -99,9 +102,9 @@ def _(celery_app: Celery) -> None:
    return _


+@pytest.mark.usefixtures("with_celery_worker")
async def test_submitting_task_calling_async_function_results_with_success_state(
    celery_task_manager: CeleryTaskManager,
-    with_celery_worker: WorkController,
):

    owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner")
@@ -133,9 +136,9 @@ async def test_submitting_task_calling_async_function_results_with_success_state
    ) == "archive.zip"


+@pytest.mark.usefixtures("with_celery_worker")
async def test_submitting_task_with_failure_results_with_error(
    celery_task_manager: CeleryTaskManager,
-    with_celery_worker: WorkController,
):

    owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner")
@@ -163,9 +166,9 @@ async def test_submitting_task_with_failure_results_with_error(
    assert f"{raw_result}" == "Something strange happened: BOOM!"


+@pytest.mark.usefixtures("with_celery_worker")
async def test_cancelling_a_running_task_aborts_and_deletes(
    celery_task_manager: CeleryTaskManager,
-    with_celery_worker: WorkController,
):

    owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner")
@@ -187,9 +190,9 @@ async def test_cancelling_a_running_task_aborts_and_deletes(
    assert task_uuid not in await celery_task_manager.list_tasks(owner_metadata)


+@pytest.mark.usefixtures("with_celery_worker")
async def test_listing_task_uuids_contains_submitted_task(
    celery_task_manager: CeleryTaskManager,
-    with_celery_worker: WorkController,
):

    owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner")
@@ -214,9 +217,9 @@ async def test_listing_task_uuids_contains_submitted_task(
    assert any(task.uuid == task_uuid for task in tasks)


+@pytest.mark.usefixtures("with_celery_worker")
async def test_filtering_listing_tasks(
    celery_task_manager: CeleryTaskManager,
-    with_celery_worker: WorkController,
):
    class MyOwnerMetadata(OwnerMetadata):
        user_id: int
@@ -266,3 +269,175 @@ class MyOwnerMetadata(OwnerMetadata):
    # clean up all tasks. this should ideally be done in the fixture
    for task_uuid, owner_metadata in all_tasks:
        await celery_task_manager.cancel_task(owner_metadata, task_uuid)
+
+
+@pytest.mark.usefixtures("with_celery_worker")
+async def test_publish_task_event_creates_data_event(
+    celery_task_manager: CeleryTaskManager,
+):
+    """Test that publishing a data event works correctly."""
+    owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner")
+
+    # Create a task first
+    task_uuid = await celery_task_manager.submit_task(
+        ExecutionMetadata(
+            name=fake_file_processor.__name__,
+        ),
+        owner_metadata=owner_metadata,
+        files=[f"file{n}" for n in range(2)],
+    )
+
+    # Create and publish a data event
+    task_id = owner_metadata.model_dump_task_id(task_uuid=task_uuid)
+    data_event = TaskDataEvent(data={"progress": 0.5, "message": "Processing..."})
+
+    # This should not raise an exception
+    await celery_task_manager.publish_task_event(task_id, data_event)
+
+    # Clean up
+    await celery_task_manager.cancel_task(owner_metadata, task_uuid)
+
+
+@pytest.mark.usefixtures("with_celery_worker")
+async def test_publish_task_event_creates_status_event(
+    celery_task_manager: CeleryTaskManager,
+):
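+    """Test that publishing a status event works correctly."""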
+    owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner")
+
+    task_uuid = await celery_task_manager.submit_task(
+        ExecutionMetadata(
+            name=fake_file_processor.__name__,
+        ),
+        owner_metadata=owner_metadata,
+        files=[f"file{n}" for n in range(2)],
+    )
+
+    task_id = owner_metadata.model_dump_task_id(task_uuid=task_uuid)
+    status_event = TaskStatusEvent(data=TaskStatusValue.SUCCESS)
+
+    await celery_task_manager.publish_task_event(task_id, status_event)
+
+    await celery_task_manager.cancel_task(owner_metadata, task_uuid)
+
+
+@pytest.mark.usefixtures("with_celery_worker")
+async def test_consume_task_events_reads_published_events(
+    celery_task_manager: CeleryTaskManager,
+):
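+    """Test that consuming events returns the published data and status events."""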
+    owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner")
+
+    task_uuid = await celery_task_manager.submit_task(
+        ExecutionMetadata(
+            name=fake_file_processor.__name__,
+        ),
+        owner_metadata=owner_metadata,
+        files=[f"file{n}" for n in range(2)],
+    )
+
+    task_id = owner_metadata.model_dump_task_id(task_uuid=task_uuid)
+
+    data_event = TaskDataEvent(data={"progress": 0.3, "message": "Starting..."})
+    status_event = TaskStatusEvent(data=TaskStatusValue.SUCCESS)
+
+    await celery_task_manager.publish_task_event(task_id, data_event)
+    await celery_task_manager.publish_task_event(task_id, status_event)
+
+    # Consume events
+    events_received = []
+    async for event_id, event in celery_task_manager.consume_task_events(
+        owner_metadata=owner_metadata,
+        task_uuid=task_uuid,
+    ):
+        events_received.append((event_id, event))
+        if len(events_received) >= 2:
+            break
+
+    assert len(events_received) >= 1
+
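+    # Split the received events by type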
+    data_events = [
+        event for _, event in events_received if event.type == TaskEventType.DATA
+    ]
+    status_events = [
+        event for _, event in events_received if event.type == TaskEventType.STATUS
+    ]
+
+    assert len(data_events) >= 1
+    assert data_events[0].data == {"progress": 0.3, "message": "Starting..."}
+
+    success_events = [
+        event for event in status_events if event.data == TaskStatusValue.SUCCESS
+    ]
+    assert len(success_events) >= 1
+
+    await celery_task_manager.cancel_task(owner_metadata, task_uuid)
+
+
+@pytest.mark.usefixtures("with_celery_worker")
+async def test_consume_task_events_with_last_id_filters_correctly(
+    celery_task_manager: CeleryTaskManager,
+):
+    """Test that consuming task events with last_id parameter works correctly."""
+    owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner")
+
+    task_uuid = await celery_task_manager.submit_task(
+        ExecutionMetadata(
+            name=fake_file_processor.__name__,
+        ),
+        owner_metadata=owner_metadata,
+        files=[f"file{n}" for n in range(2)],
+    )
+
+    task_id = owner_metadata.model_dump_task_id(task_uuid=task_uuid)
+    first_event = TaskDataEvent(data={"progress": 0.1, "message": "First event"})
+    await celery_task_manager.publish_task_event(task_id, first_event)
+
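+    # Consume events until the first one shows up and record its event ID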
+    first_event_id = None
+    async for event_id, event in celery_task_manager.consume_task_events(
+        owner_metadata=owner_metadata,
+        task_uuid=task_uuid,
+    ):
+        if (
+            event.type == TaskEventType.DATA
+            and event.data.get("message") == "First event"
+        ):
+            first_event_id = event_id
+            break
+
+    assert first_event_id is not None
+
+    second_event = TaskDataEvent(data={"progress": 0.5, "message": "Second event"})
+    await celery_task_manager.publish_task_event(task_id, second_event)
+
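+    # With last_id set, only events published after the first one should be returned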
+    events_after_first = []
+    async for event_id, event in celery_task_manager.consume_task_events(
+        owner_metadata=owner_metadata,
+        task_uuid=task_uuid,
+        last_id=first_event_id,
+    ):
+        events_after_first.append((event_id, event))
+        if (
+            event.type == TaskEventType.DATA
+            and event.data.get("message") == "Second event"
+        ):
+            break
+
+    assert len(events_after_first) >= 1
+    data_events_after = [
+        event for _, event in events_after_first if event.type == TaskEventType.DATA
+    ]
+
+    first_event_messages = [
+        event.data.get("message")
+        for event in data_events_after
+        if event.data.get("message") == "First event"
+    ]
+    assert len(first_event_messages) == 0
+
+    second_event_messages = [
+        event.data.get("message")
+        for event in data_events_after
+        if event.data.get("message") == "Second event"
+    ]
+    assert len(second_event_messages) >= 1
+
+    await celery_task_manager.cancel_task(owner_metadata, task_uuid)