43 changes: 43 additions & 0 deletions apps/worker/celery_config.py
@@ -3,13 +3,15 @@
import logging
from datetime import timedelta

import sentry_sdk
from celery import signals
from celery.beat import BeatLazyFunc
from celery.schedules import crontab

from celery_task_router import route_task
from helpers.clock import get_utc_now_as_iso_format
from helpers.health_check import get_health_check_interval_seconds
from helpers.sentry import initialize_sentry, is_sentry_enabled
from shared.celery_config import (
    BaseCeleryConfig,
    brolly_stats_rollup_task_name,
@@ -51,6 +53,47 @@ def initialize_cache(**kwargs):
    cache.configure(redis_cache_backend)


@signals.worker_process_init.connect
def initialize_sentry_for_worker_process(**kwargs):
    """Initialize Sentry in each forked worker process.

    Celery workers fork from the main process, so Sentry must be re-initialized
    in each worker to ensure proper error tracking and trace propagation.
    """
    if is_sentry_enabled():
        initialize_sentry()


@signals.worker_shutting_down.connect
def flush_sentry_on_shutdown(**kwargs):
    """Flush pending Sentry events before worker shutdown.

    This ensures crash data is sent before the process terminates, which is
    especially important for graceful shutdowns (SIGTERM before SIGKILL).
    """
    client = sentry_sdk.get_client()
    if client.is_active():
        client.flush(timeout=10)


@signals.task_prerun.connect
def set_sentry_task_context(task_id, task, args, kwargs, **kw):
    """Set Sentry context at the start of each task execution.

    This adds task args/kwargs to Sentry events for debugging. Note that
    task_name and task_id tags are already set by LogContext.add_to_sentry().
    """
    sentry_sdk.set_context(
        "celery_task",
        {
            "task_name": task.name,
            "task_id": task_id,
            "args": str(args)[:500] if args else None,
            "kwargs": str(kwargs)[:500] if kwargs else None,
        },
    )


# Cron task names
hourly_check_task_name = "app.cron.hourly_check.HourlyCheckTask"
daily_plan_manager_task_name = "app.cron.daily.PlanManagerTask"
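
The handlers above depend on initialize_sentry and is_sentry_enabled from helpers.sentry, a module this diff does not touch. A minimal sketch of what that helper could look like, assuming it simply wraps sentry_sdk.init with the SDK's Celery integration (the environment-variable lookup and sample rate are illustrative assumptions, not taken from the repository):

# Hypothetical sketch of helpers/sentry.py -- not part of this change.
import os

import sentry_sdk
from sentry_sdk.integrations.celery import CeleryIntegration


def is_sentry_enabled() -> bool:
    # Assumption: enablement is keyed off the presence of a DSN.
    return bool(os.environ.get("SENTRY_DSN"))


def initialize_sentry() -> None:
    # CeleryIntegration keeps trace propagation intact between the process
    # that enqueues a task and the worker process that executes it.
    sentry_sdk.init(
        dsn=os.environ["SENTRY_DSN"],
        integrations=[CeleryIntegration()],
        traces_sample_rate=0.1,  # illustrative sample rate
    )

Because worker_process_init fires once in every forked child, each child ends up with its own initialized client, and the task_prerun handler then layers per-task context on top of whatever the integration already records.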
59 changes: 45 additions & 14 deletions apps/worker/tasks/base.py
@@ -81,6 +81,8 @@ def on_timeout(self, soft: bool, timeout: int):
        res = super().on_timeout(soft, timeout)
        if not soft:
            REQUEST_HARD_TIMEOUT_COUNTER.labels(task=self.name).inc()
            self._capture_hard_timeout_to_sentry(timeout)

        REQUEST_TIMEOUT_COUNTER.labels(task=self.name).inc()

        if UploadFlow.has_begun():
@@ -90,6 +92,28 @@

        return res

    def _capture_hard_timeout_to_sentry(self, timeout: int):
        """Capture hard timeout to Sentry with immediate flush.

        Hard timeouts result in SIGKILL, so we must flush immediately
        to ensure the event is sent before the process is killed.
        """
        scope = sentry_sdk.get_current_scope()
        scope.set_tag("failure_type", "hard_timeout")
        scope.set_context(
            "timeout",
            {
                "task_name": self.name,
                "task_id": self.id,
                "timeout_seconds": timeout,
            },
        )
        sentry_sdk.capture_message(
            f"Hard timeout in task {self.name} after {timeout}s",
            level="error",
        )
        sentry_sdk.flush(timeout=2)


# Task reliability metrics
TASK_RUN_COUNTER = Counter(
@@ -446,42 +470,28 @@ def run(self, *args, **kwargs):
                db_session.rollback()
                retry_count = getattr(self.request, "retries", 0)
                countdown = TASK_RETRY_BACKOFF_BASE_SECONDS * (2**retry_count)
                # Use safe_retry to handle max retries exceeded gracefully
                # Returns False if max retries exceeded, otherwise raises Retry
                # Wrap in try-except to catch MaxRetriesExceededError if it escapes safe_retry
                # (exceptions raised inside except blocks aren't caught by sibling except clauses)
                try:
                    if not self.safe_retry(countdown=countdown):
                        # Max retries exceeded - return None to match old behavior
                        return None
                except MaxRetriesExceededError:
                    # Handle MaxRetriesExceededError if it escapes safe_retry
                    if UploadFlow.has_begun():
                        UploadFlow.log(UploadFlow.UNCAUGHT_RETRY_EXCEPTION)
                    if TestResultsFlow.has_begun():
                        TestResultsFlow.log(TestResultsFlow.UNCAUGHT_RETRY_EXCEPTION)
                    # Return None to match old behavior
                    return None
            except SQLAlchemyError as ex:
                self._analyse_error(ex, args, kwargs)
                db_session.rollback()
                retry_count = getattr(self.request, "retries", 0)
                countdown = TASK_RETRY_BACKOFF_BASE_SECONDS * (2**retry_count)
                # Use safe_retry to handle max retries exceeded gracefully
                # Returns False if max retries exceeded, otherwise raises Retry
                # Wrap in try-except to catch MaxRetriesExceededError if it escapes safe_retry
                # (exceptions raised inside except blocks aren't caught by sibling except clauses)
                try:
                    if not self.safe_retry(countdown=countdown):
                        # Max retries exceeded - return None to match old behavior
                        return None
                except MaxRetriesExceededError:
                    # Handle MaxRetriesExceededError if it escapes safe_retry
                    if UploadFlow.has_begun():
                        UploadFlow.log(UploadFlow.UNCAUGHT_RETRY_EXCEPTION)
                    if TestResultsFlow.has_begun():
                        TestResultsFlow.log(TestResultsFlow.UNCAUGHT_RETRY_EXCEPTION)
                    # Return None to match old behavior
                    return None
            except MaxRetriesExceededError as ex:
                if UploadFlow.has_begun():
@@ -541,6 +551,8 @@ def on_success(self, retval, task_id, args, kwargs):

    def on_failure(self, exc, task_id, args, kwargs, einfo):
        """Handle task failure, logging flow checkpoints and incrementing metrics."""
        self._capture_failure_to_sentry(exc, task_id, args, kwargs)

        res = super().on_failure(exc, task_id, args, kwargs, einfo)
        self.task_failure_counter.inc()

@@ -551,6 +563,25 @@ def on_failure(self, exc, task_id, args, kwargs, einfo):

        return res

    def _capture_failure_to_sentry(self, exc, task_id, args, kwargs):
        """Capture task failure to Sentry with rich context."""
        scope = sentry_sdk.get_current_scope()
        scope.set_tag("failure_type", "task_failure")
        scope.set_context(
            "celery",
            {
                "task_name": self.name,
                "task_id": task_id,
                "args": str(args)[:500] if args else None,
                "kwargs": (
                    {k: str(v)[:100] for k, v in kwargs.items()} if kwargs else None
                ),
                "retries": getattr(self.request, "retries", 0),
                "attempts": self.attempts,
            },
        )
        sentry_sdk.capture_exception(exc)

    def get_repo_provider_service(
        self,
        repository: Repository,
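
The hard-timeout path flushes synchronously because the worker is about to be SIGKILLed and will never get another chance to drain the SDK's transport queue; on_failure, by contrast, leaves the process alive, so the SDK's background worker can deliver the event on its own. A standalone sketch of the same capture-then-flush pattern outside Celery (the DSN and handler name below are placeholders, not from this codebase):

# Illustration of capture-then-flush before forced termination.
import signal
import sys

import sentry_sdk

sentry_sdk.init(dsn="https://examplePublicKey@o0.ingest.sentry.io/0")  # placeholder DSN


def handle_sigterm(signum, frame):
    scope = sentry_sdk.get_current_scope()
    scope.set_tag("failure_type", "sigterm")
    sentry_sdk.capture_message("Process received SIGTERM", level="error")
    # Block (up to 2s) until the event has left the process; exiting
    # immediately could drop anything still queued in the SDK.
    sentry_sdk.flush(timeout=2)
    sys.exit(0)


signal.signal(signal.SIGTERM, handle_sigterm)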