Skip to content

Commit ea70df8

Browse files
committed
feat: propagate MessageGroupId through job processing tasks
- Make process_job and shatter_job_rows bind=True and pass message_group_id through to child tasks via getattr(self, message_group_id, None) - Add message_group_id to process_job_row and save_* apply_async calls - process_incomplete_job passes service_id as message_group_id - Update tests: expect None when tasks run directly; add test for propagation
1 parent 88ba77f commit ea70df8

File tree

4 files changed

+266
-97
lines changed

4 files changed

+266
-97
lines changed

app/celery/tasks.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ class ProcessReportRequestException(Exception):
6868
pass
6969

7070

71-
@notify_celery.task(name="process-job")
72-
def process_job(job_id, sender_id=None, shatter_batch_size=DEFAULT_SHATTER_JOB_ROWS_BATCH_SIZE):
71+
@notify_celery.task(bind=True, name="process-job")
72+
def process_job(self, job_id, sender_id=None, shatter_batch_size=DEFAULT_SHATTER_JOB_ROWS_BATCH_SIZE):
7373
start = datetime.utcnow()
7474
job = dao_get_job_by_id(job_id)
7575
current_app.logger.info(
@@ -116,19 +116,20 @@ def process_job(job_id, sender_id=None, shatter_batch_size=DEFAULT_SHATTER_JOB_R
116116
get_id_task_args_kwargs_for_job_row(row, template, job, service, sender_id=sender_id)[1]
117117
for row in shatter_batch
118118
]
119-
_shatter_job_rows_with_subdivision(template.template_type, batch_args_kwargs)
119+
_shatter_job_rows_with_subdivision(template.template_type, batch_args_kwargs, self.message_group_id)
120120

121121
job_complete(job, start=start)
122122

123123

124-
def _shatter_job_rows_with_subdivision(template_type, args_kwargs_seq, top_level=True):
124+
def _shatter_job_rows_with_subdivision(template_type, args_kwargs_seq, message_group_id, top_level=True):
125125
try:
126126
shatter_job_rows.apply_async(
127127
(
128128
template_type,
129129
args_kwargs_seq,
130130
),
131131
queue=QueueNames.JOBS,
132+
MessageGroupId=message_group_id,
132133
)
133134
except BotoClientError as e:
134135
# this information is helpfully not preserved outside the message string of the exception, so
@@ -146,7 +147,7 @@ def _shatter_job_rows_with_subdivision(template_type, args_kwargs_seq, top_level
146147
raise UnprocessableJobRow from e
147148

148149
for sub_batch in (args_kwargs_seq[:split_batch_size], args_kwargs_seq[split_batch_size:]):
149-
_shatter_job_rows_with_subdivision(template_type, sub_batch, top_level=False)
150+
_shatter_job_rows_with_subdivision(template_type, sub_batch, message_group_id, top_level=False)
150151

151152
else:
152153
if not top_level:
@@ -156,16 +157,17 @@ def _shatter_job_rows_with_subdivision(template_type, args_kwargs_seq, top_level
156157
)
157158

158159

159-
@notify_celery.task(name="shatter-job-rows")
160+
@notify_celery.task(bind=True, name="shatter-job-rows")
160161
def shatter_job_rows(
162+
self,
161163
template_type: str,
162164
args_kwargs_seq: Sequence,
163165
):
164166
for task_args_kwargs in args_kwargs_seq:
165-
process_job_row(template_type, task_args_kwargs)
167+
process_job_row(template_type, task_args_kwargs, self.message_group_id)
166168

167169

168-
def process_job_row(template_type, task_args_kwargs):
170+
def process_job_row(template_type, task_args_kwargs, message_group_id=None):
169171
send_fn = {
170172
SMS_TYPE: save_sms,
171173
EMAIL_TYPE: save_email,
@@ -175,6 +177,7 @@ def process_job_row(template_type, task_args_kwargs):
175177
send_fn.apply_async(
176178
*task_args_kwargs,
177179
queue=QueueNames.DATABASE,
180+
MessageGroupId=message_group_id,
178181
)
179182

180183

@@ -353,6 +356,7 @@ def save_sms(
353356
provider_tasks.deliver_sms.apply_async(
354357
[str(saved_notification.id)],
355358
queue=QueueNames.SEND_SMS,
359+
MessageGroupId=self.message_group_id,
356360
)
357361
else:
358362
extra = {
@@ -437,6 +441,7 @@ def save_email(self, service_id, notification_id, encoded_notification, sender_i
437441
provider_tasks.deliver_email.apply_async(
438442
[str(saved_notification.id)],
439443
queue=QueueNames.SEND_EMAIL,
444+
MessageGroupId=self.message_group_id,
440445
)
441446

442447
extra = {
@@ -494,7 +499,9 @@ def save_letter(
494499
)
495500

496501
letters_pdf_tasks.get_pdf_for_templated_letter.apply_async(
497-
[str(saved_notification.id)], queue=QueueNames.CREATE_LETTERS_PDF
502+
[str(saved_notification.id)],
503+
queue=QueueNames.CREATE_LETTERS_PDF,
504+
MessageGroupId=self.message_group_id,
498505
)
499506

500507
extra = {
@@ -529,7 +536,12 @@ def handle_exception(task, notification, notification_id, exc):
529536
# send to the retry queue.
530537
current_app.logger.exception("Retry: " + base_msg, extra, extra=extra) # noqa
531538
try:
532-
task.retry(queue=QueueNames.RETRY, exc=exc)
539+
retry_kwargs = {
540+
"queue": QueueNames.RETRY,
541+
"exc": exc,
542+
"MessageGroupId": getattr(task, "message_group_id", None),
543+
}
544+
task.retry(**retry_kwargs)
533545
except task.MaxRetriesExceededError:
534546
current_app.logger.error("Max retry failed: " + base_msg, extra, extra=extra) # noqa
535547

@@ -580,7 +592,7 @@ def process_incomplete_job(job_id, shatter_batch_size=DEFAULT_SHATTER_JOB_ROWS_B
580592
get_id_task_args_kwargs_for_job_row(row, template, job, job.service, sender_id=sender_id)[1]
581593
for row in shatter_batch
582594
]
583-
_shatter_job_rows_with_subdivision(template.template_type, batch_args_kwargs)
595+
_shatter_job_rows_with_subdivision(template.template_type, batch_args_kwargs, str(job.service_id))
584596

585597
job_complete(job, resumed=True)
586598

@@ -625,7 +637,11 @@ def _check_and_queue_returned_letter_callback_task(notification_id, service_id):
625637
# queue callback task only if the service_callback_api exists
626638
if service_callback_api := get_returned_letter_callback_api_for_service(service_id=service_id):
627639
returned_letter_data = create_returned_letter_callback_data(notification_id, service_id, service_callback_api)
628-
send_returned_letter_to_service.apply_async([returned_letter_data], queue=QueueNames.CALLBACKS)
640+
send_returned_letter_to_service.apply_async(
641+
[returned_letter_data],
642+
queue=QueueNames.CALLBACKS,
643+
MessageGroupId=str(service_id),
644+
)
629645

630646

631647
@notify_celery.task(bind=True, name="process-report-request")

tests/app/celery/test_scheduled_tasks.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,12 @@ def test_check_for_missing_rows_in_completed_jobs(mocker, sample_email_template,
724724
)
725725
]
726726
assert mock_save_email.mock_calls == [
727-
mock.call((str(job.service_id), "some-uuid", "something_encoded"), {}, queue="database-tasks")
727+
mock.call(
728+
(str(job.service_id), "some-uuid", "something_encoded"),
729+
{},
730+
queue="database-tasks",
731+
MessageGroupId=None,
732+
)
728733
]
729734

730735

@@ -765,7 +770,10 @@ def test_check_for_missing_rows_in_completed_jobs_uses_sender_id(
765770
]
766771
assert mock_save_email.mock_calls == [
767772
mock.call(
768-
(str(job.service_id), "some-uuid", "something_encoded"), {"sender_id": fake_uuid}, queue="database-tasks"
773+
(str(job.service_id), "some-uuid", "something_encoded"),
774+
{"sender_id": fake_uuid},
775+
queue="database-tasks",
776+
MessageGroupId=None,
769777
)
770778
]
771779

0 commit comments

Comments
 (0)