Skip to content

Commit 91d6414

Browse files
committed
feat: propagate MessageGroupId through job processing tasks
- Make process_job and shatter_job_rows bind=True and pass message_group_id through to child tasks via getattr(self, "message_group_id", None)
- Add a message_group_id parameter to process_job_row, and pass MessageGroupId in the process_job_row and save_* apply_async calls
- process_incomplete_job passes the job's service_id (as a string) as message_group_id
- Update tests: expect MessageGroupId=None when tasks run directly; add test for propagation
1 parent 88ba77f commit 91d6414

File tree

4 files changed

+272
-97
lines changed

4 files changed

+272
-97
lines changed

app/celery/tasks.py

Lines changed: 30 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -68,8 +68,8 @@ class ProcessReportRequestException(Exception):
6868
pass
6969

7070

71-
@notify_celery.task(name="process-job")
72-
def process_job(job_id, sender_id=None, shatter_batch_size=DEFAULT_SHATTER_JOB_ROWS_BATCH_SIZE):
71+
@notify_celery.task(bind=True, name="process-job")
72+
def process_job(self, job_id, sender_id=None, shatter_batch_size=DEFAULT_SHATTER_JOB_ROWS_BATCH_SIZE):
7373
start = datetime.utcnow()
7474
job = dao_get_job_by_id(job_id)
7575
current_app.logger.info(
@@ -116,19 +116,22 @@ def process_job(job_id, sender_id=None, shatter_batch_size=DEFAULT_SHATTER_JOB_R
116116
get_id_task_args_kwargs_for_job_row(row, template, job, service, sender_id=sender_id)[1]
117117
for row in shatter_batch
118118
]
119-
_shatter_job_rows_with_subdivision(template.template_type, batch_args_kwargs)
119+
_shatter_job_rows_with_subdivision(
120+
template.template_type, batch_args_kwargs, getattr(self, "message_group_id", None)
121+
)
120122

121123
job_complete(job, start=start)
122124

123125

124-
def _shatter_job_rows_with_subdivision(template_type, args_kwargs_seq, top_level=True):
126+
def _shatter_job_rows_with_subdivision(template_type, args_kwargs_seq, message_group_id, top_level=True):
125127
try:
126128
shatter_job_rows.apply_async(
127129
(
128130
template_type,
129131
args_kwargs_seq,
130132
),
131133
queue=QueueNames.JOBS,
134+
MessageGroupId=message_group_id,
132135
)
133136
except BotoClientError as e:
134137
# this information is helpfully not preserved outside the message string of the exception, so
@@ -146,7 +149,7 @@ def _shatter_job_rows_with_subdivision(template_type, args_kwargs_seq, top_level
146149
raise UnprocessableJobRow from e
147150

148151
for sub_batch in (args_kwargs_seq[:split_batch_size], args_kwargs_seq[split_batch_size:]):
149-
_shatter_job_rows_with_subdivision(template_type, sub_batch, top_level=False)
152+
_shatter_job_rows_with_subdivision(template_type, sub_batch, message_group_id, top_level=False)
150153

151154
else:
152155
if not top_level:
@@ -156,16 +159,17 @@ def _shatter_job_rows_with_subdivision(template_type, args_kwargs_seq, top_level
156159
)
157160

158161

159-
@notify_celery.task(name="shatter-job-rows")
162+
@notify_celery.task(bind=True, name="shatter-job-rows")
160163
def shatter_job_rows(
164+
self,
161165
template_type: str,
162166
args_kwargs_seq: Sequence,
163167
):
164168
for task_args_kwargs in args_kwargs_seq:
165-
process_job_row(template_type, task_args_kwargs)
169+
process_job_row(template_type, task_args_kwargs, getattr(self, "message_group_id", None))
166170

167171

168-
def process_job_row(template_type, task_args_kwargs):
172+
def process_job_row(template_type, task_args_kwargs, message_group_id=None):
169173
send_fn = {
170174
SMS_TYPE: save_sms,
171175
EMAIL_TYPE: save_email,
@@ -175,6 +179,7 @@ def process_job_row(template_type, task_args_kwargs):
175179
send_fn.apply_async(
176180
*task_args_kwargs,
177181
queue=QueueNames.DATABASE,
182+
MessageGroupId=message_group_id,
178183
)
179184

180185

@@ -353,6 +358,7 @@ def save_sms(
353358
provider_tasks.deliver_sms.apply_async(
354359
[str(saved_notification.id)],
355360
queue=QueueNames.SEND_SMS,
361+
MessageGroupId=getattr(self, "message_group_id", None),
356362
)
357363
else:
358364
extra = {
@@ -437,6 +443,7 @@ def save_email(self, service_id, notification_id, encoded_notification, sender_i
437443
provider_tasks.deliver_email.apply_async(
438444
[str(saved_notification.id)],
439445
queue=QueueNames.SEND_EMAIL,
446+
MessageGroupId=getattr(self, "message_group_id", None),
440447
)
441448

442449
extra = {
@@ -494,7 +501,9 @@ def save_letter(
494501
)
495502

496503
letters_pdf_tasks.get_pdf_for_templated_letter.apply_async(
497-
[str(saved_notification.id)], queue=QueueNames.CREATE_LETTERS_PDF
504+
[str(saved_notification.id)],
505+
queue=QueueNames.CREATE_LETTERS_PDF,
506+
MessageGroupId=getattr(self, "message_group_id", None),
498507
)
499508

500509
extra = {
@@ -529,7 +538,12 @@ def handle_exception(task, notification, notification_id, exc):
529538
# send to the retry queue.
530539
current_app.logger.exception("Retry: " + base_msg, extra, extra=extra) # noqa
531540
try:
532-
task.retry(queue=QueueNames.RETRY, exc=exc)
541+
retry_kwargs = {
542+
"queue": QueueNames.RETRY,
543+
"exc": exc,
544+
"MessageGroupId": getattr(task, "message_group_id", None),
545+
}
546+
task.retry(**retry_kwargs)
533547
except task.MaxRetriesExceededError:
534548
current_app.logger.error("Max retry failed: " + base_msg, extra, extra=extra) # noqa
535549

@@ -580,7 +594,7 @@ def process_incomplete_job(job_id, shatter_batch_size=DEFAULT_SHATTER_JOB_ROWS_B
580594
get_id_task_args_kwargs_for_job_row(row, template, job, job.service, sender_id=sender_id)[1]
581595
for row in shatter_batch
582596
]
583-
_shatter_job_rows_with_subdivision(template.template_type, batch_args_kwargs)
597+
_shatter_job_rows_with_subdivision(template.template_type, batch_args_kwargs, str(job.service_id))
584598

585599
job_complete(job, resumed=True)
586600

@@ -625,7 +639,11 @@ def _check_and_queue_returned_letter_callback_task(notification_id, service_id):
625639
# queue callback task only if the service_callback_api exists
626640
if service_callback_api := get_returned_letter_callback_api_for_service(service_id=service_id):
627641
returned_letter_data = create_returned_letter_callback_data(notification_id, service_id, service_callback_api)
628-
send_returned_letter_to_service.apply_async([returned_letter_data], queue=QueueNames.CALLBACKS)
642+
send_returned_letter_to_service.apply_async(
643+
[returned_letter_data],
644+
queue=QueueNames.CALLBACKS,
645+
MessageGroupId=str(service_id),
646+
)
629647

630648

631649
@notify_celery.task(bind=True, name="process-report-request")

tests/app/celery/test_scheduled_tasks.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -724,7 +724,12 @@ def test_check_for_missing_rows_in_completed_jobs(mocker, sample_email_template,
724724
)
725725
]
726726
assert mock_save_email.mock_calls == [
727-
mock.call((str(job.service_id), "some-uuid", "something_encoded"), {}, queue="database-tasks")
727+
mock.call(
728+
(str(job.service_id), "some-uuid", "something_encoded"),
729+
{},
730+
queue="database-tasks",
731+
MessageGroupId=None,
732+
)
728733
]
729734

730735

@@ -765,7 +770,10 @@ def test_check_for_missing_rows_in_completed_jobs_uses_sender_id(
765770
]
766771
assert mock_save_email.mock_calls == [
767772
mock.call(
768-
(str(job.service_id), "some-uuid", "something_encoded"), {"sender_id": fake_uuid}, queue="database-tasks"
773+
(str(job.service_id), "some-uuid", "something_encoded"),
774+
{"sender_id": fake_uuid},
775+
queue="database-tasks",
776+
MessageGroupId=None,
769777
)
770778
]
771779

0 commit comments

Comments
 (0)