Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements-dev-frozen.txt
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ sentry-forked-django-stubs==5.1.1.post1
sentry-forked-djangorestframework-stubs==3.15.1.post2
sentry-kafka-schemas==0.1.122
sentry-ophio==1.0.0
sentry-protos==0.1.37
sentry-protos==0.1.39
sentry-redis-tools==0.1.7
sentry-relay==0.9.3
sentry-sdk==2.19.2
Expand Down
2 changes: 1 addition & 1 deletion requirements-frozen.txt
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ s3transfer==0.10.0
sentry-arroyo==2.18.2
sentry-kafka-schemas==0.1.122
sentry-ophio==1.0.0
sentry-protos==0.1.37
sentry-protos==0.1.39
sentry-redis-tools==0.1.7
sentry-relay==0.9.3
sentry-sdk==2.19.2
Expand Down
9 changes: 7 additions & 2 deletions src/sentry/runner/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,13 +243,18 @@ def worker(ignore_unknown_queues: bool, **options: Any) -> None:
@click.option(
"--max-task-count", help="Number of tasks this worker should run before exiting", default=10000
)
@click.option(
"--namespace", help="The dedicated task namespace that taskworker operates on", default=None
)
@log_options()
@configuration
def taskworker(rpc_host: str, max_task_count: int, **options: Any) -> None:
def taskworker(rpc_host: str, max_task_count: int, namespace: str | None, **options: Any) -> None:
from sentry.taskworker.worker import TaskWorker

with managed_bgtasks(role="taskworker"):
worker = TaskWorker(rpc_host=rpc_host, max_task_count=max_task_count, **options)
worker = TaskWorker(
rpc_host=rpc_host, max_task_count=max_task_count, namespace=namespace, **options
)
exitcode = worker.start()
raise SystemExit(exitcode)

Expand Down
14 changes: 8 additions & 6 deletions src/sentry/taskworker/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import grpc
from sentry_protos.sentry.v1.taskworker_pb2 import (
FetchNextTask,
GetTaskRequest,
SetTaskStatusRequest,
TaskActivation,
Expand All @@ -24,13 +25,14 @@ def __init__(self, host: str) -> None:
self._channel = grpc.insecure_channel(self._host)
self._stub = ConsumerServiceStub(self._channel)

def get_task(self) -> TaskActivation | None:
def get_task(self, namespace: str | None = None) -> TaskActivation | None:
"""
Fetch a pending task
Fetch a pending task.

Will return None when there are no tasks to fetch
If a namespace is provided, only tasks for that namespace will be fetched.
This will return None if there are no tasks to fetch.
"""
request = GetTaskRequest()
request = GetTaskRequest(namespace=namespace)
try:
response = self._stub.GetTask(request)
except grpc.RpcError as err:
Expand All @@ -42,7 +44,7 @@ def get_task(self) -> TaskActivation | None:
return None

def update_task(
self, task_id: str, status: TaskActivationStatus.ValueType, fetch_next: bool = True
self, task_id: str, status: TaskActivationStatus.ValueType, fetch_next_task: FetchNextTask
) -> TaskActivation | None:
"""
Update the status for a given task activation.
Expand All @@ -52,7 +54,7 @@ def update_task(
request = SetTaskStatusRequest(
id=task_id,
status=status,
fetch_next=fetch_next,
fetch_next_task=fetch_next_task,
)
try:
response = self._stub.SetTaskStatus(request)
Expand Down
12 changes: 10 additions & 2 deletions src/sentry/taskworker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
TASK_ACTIVATION_STATUS_COMPLETE,
TASK_ACTIVATION_STATUS_FAILURE,
TASK_ACTIVATION_STATUS_RETRY,
FetchNextTask,
TaskActivation,
)

Expand Down Expand Up @@ -57,12 +58,17 @@ class TaskWorker:
"""

def __init__(
self, rpc_host: str, max_task_count: int | None = None, **options: dict[str, Any]
self,
rpc_host: str,
max_task_count: int | None = None,
namespace: str | None = None,
**options: dict[str, Any],
) -> None:
self.options = options
self._execution_count = 0
self._worker_id = uuid4().hex
self._max_task_count = max_task_count
self._namespace = namespace
self.client = TaskworkerClient(rpc_host)
self._pool: Pool | None = None
self._build_pool()
Expand Down Expand Up @@ -124,7 +130,7 @@ def start(self) -> int:

def fetch_task(self) -> TaskActivation | None:
try:
activation = self.client.get_task()
activation = self.client.get_task(self._namespace)
except grpc.RpcError:
metrics.incr("taskworker.worker.get_task.failed")
logger.info("get_task failed. Retrying in 1 second")
Expand Down Expand Up @@ -167,6 +173,7 @@ def process_task(self, activation: TaskActivation) -> TaskActivation | None:
return self.client.update_task(
task_id=activation.id,
status=TASK_ACTIVATION_STATUS_FAILURE,
fetch_next_task=FetchNextTask(namespace=self._namespace),
)

if task.at_most_once:
Expand Down Expand Up @@ -260,4 +267,5 @@ def process_task(self, activation: TaskActivation) -> TaskActivation | None:
return self.client.update_task(
task_id=activation.id,
status=next_state,
fetch_next_task=FetchNextTask(namespace=self._namespace),
)
64 changes: 61 additions & 3 deletions tests/sentry/taskworker/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from google.protobuf.message import Message
from sentry_protos.sentry.v1.taskworker_pb2 import (
TASK_ACTIVATION_STATUS_RETRY,
FetchNextTask,
GetTaskResponse,
SetTaskStatusResponse,
TaskActivation,
Expand Down Expand Up @@ -97,6 +98,31 @@ def test_get_task_ok():
assert result.namespace == "testing"


def test_get_task_with_namespace():
    """
    Call ``get_task`` with an explicit namespace against a mocked channel.

    A ``GetTaskResponse`` wrapping a "testing"-namespace activation is
    registered on the mock channel for the GetTask RPC path; the test then
    checks that the client hands back an activation with an id and that
    namespace.
    """
    # Activation registered as the canned GetTask response.
    activation = TaskActivation(
        id="abc123",
        namespace="testing",
        taskname="do_thing",
        parameters="",
        headers={},
        processing_deadline_duration=10,
    )
    fake_channel = MockChannel()
    fake_channel.add_response(
        "/sentry_protos.sentry.v1.ConsumerService/GetTask",
        GetTaskResponse(task=activation),
    )

    # Swap the real gRPC channel factory for one returning the mock channel.
    with patch("sentry.taskworker.client.grpc.insecure_channel") as insecure_channel:
        insecure_channel.return_value = fake_channel
        taskworker_client = TaskworkerClient("localhost:50051")
        result = taskworker_client.get_task(namespace="testing")

        assert result
        assert result.id
        assert result.namespace == "testing"


def test_get_task_not_found():
channel = MockChannel()
channel.add_response(
Expand Down Expand Up @@ -142,11 +168,39 @@ def test_update_task_ok_with_next():
with patch("sentry.taskworker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskworkerClient("localhost:50051")
result = client.update_task("abc123", TASK_ACTIVATION_STATUS_RETRY)
result = client.update_task(
"abc123", TASK_ACTIVATION_STATUS_RETRY, FetchNextTask(namespace=None)
)
assert result
assert result.id == "abc123"


def test_update_task_ok_with_next_namespace():
    """
    Call ``update_task`` with a namespaced ``FetchNextTask`` on a mocked channel.

    A ``SetTaskStatusResponse`` carrying a "testing"-namespace activation is
    registered on the mock channel for the SetTaskStatus RPC path; the test
    then checks that the client returns an activation with the expected id
    and namespace.
    """
    # Activation registered as the canned SetTaskStatus response.
    next_activation = TaskActivation(
        id="abc123",
        namespace="testing",
        taskname="do_thing",
        parameters="",
        headers={},
        processing_deadline_duration=10,
    )
    fake_channel = MockChannel()
    fake_channel.add_response(
        "/sentry_protos.sentry.v1.ConsumerService/SetTaskStatus",
        SetTaskStatusResponse(task=next_activation),
    )

    # Swap the real gRPC channel factory for one returning the mock channel.
    with patch("sentry.taskworker.client.grpc.insecure_channel") as insecure_channel:
        insecure_channel.return_value = fake_channel
        taskworker_client = TaskworkerClient("localhost:50051")
        fetch_next = FetchNextTask(namespace="testing")
        result = taskworker_client.update_task(
            "abc123", TASK_ACTIVATION_STATUS_RETRY, fetch_next
        )
        assert result
        assert result.id == "abc123"
        assert result.namespace == "testing"


def test_update_task_ok_no_next():
channel = MockChannel()
channel.add_response(
Expand All @@ -155,7 +209,9 @@ def test_update_task_ok_no_next():
with patch("sentry.taskworker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskworkerClient("localhost:50051")
result = client.update_task("abc123", TASK_ACTIVATION_STATUS_RETRY)
result = client.update_task(
"abc123", TASK_ACTIVATION_STATUS_RETRY, FetchNextTask(namespace=None)
)
assert result is None


Expand All @@ -168,5 +224,7 @@ def test_update_task_not_found():
with patch("sentry.taskworker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskworkerClient("localhost:50051")
result = client.update_task("abc123", TASK_ACTIVATION_STATUS_RETRY)
result = client.update_task(
"abc123", TASK_ACTIVATION_STATUS_RETRY, FetchNextTask(namespace=None)
)
assert result is None
33 changes: 25 additions & 8 deletions tests/sentry/taskworker/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
TASK_ACTIVATION_STATUS_COMPLETE,
TASK_ACTIVATION_STATUS_FAILURE,
TASK_ACTIVATION_STATUS_RETRY,
FetchNextTask,
TaskActivation,
)

Expand Down Expand Up @@ -110,7 +111,9 @@ def test_process_task_complete(self) -> None:
result = taskworker.process_task(SIMPLE_TASK)

mock_update.assert_called_with(
task_id=SIMPLE_TASK.id, status=TASK_ACTIVATION_STATUS_COMPLETE
task_id=SIMPLE_TASK.id,
status=TASK_ACTIVATION_STATUS_COMPLETE,
fetch_next_task=FetchNextTask(namespace=None),
)

assert result
Expand All @@ -123,7 +126,9 @@ def test_process_task_retry(self) -> None:
result = taskworker.process_task(RETRY_TASK)

mock_update.assert_called_with(
task_id=RETRY_TASK.id, status=TASK_ACTIVATION_STATUS_RETRY
task_id=RETRY_TASK.id,
status=TASK_ACTIVATION_STATUS_RETRY,
fetch_next_task=FetchNextTask(namespace=None),
)

assert result
Expand All @@ -136,7 +141,9 @@ def test_process_task_failure(self) -> None:
result = taskworker.process_task(FAIL_TASK)

mock_update.assert_called_with(
task_id=FAIL_TASK.id, status=TASK_ACTIVATION_STATUS_FAILURE
task_id=FAIL_TASK.id,
status=TASK_ACTIVATION_STATUS_FAILURE,
fetch_next_task=FetchNextTask(namespace=None),
)
assert result
assert result.id == SIMPLE_TASK.id
Expand All @@ -148,7 +155,9 @@ def test_process_task_at_most_once(self) -> None:
result = taskworker.process_task(AT_MOST_ONCE_TASK)

mock_update.assert_called_with(
task_id=AT_MOST_ONCE_TASK.id, status=TASK_ACTIVATION_STATUS_COMPLETE
task_id=AT_MOST_ONCE_TASK.id,
status=TASK_ACTIVATION_STATUS_COMPLETE,
fetch_next_task=FetchNextTask(namespace=None),
)
assert taskworker.process_task(AT_MOST_ONCE_TASK) is None
assert result
Expand All @@ -169,7 +178,9 @@ def test_start_max_task_count(self) -> None:
assert result == 0
assert mock_client.get_task.called
mock_client.update_task.assert_called_with(
task_id=SIMPLE_TASK.id, status=TASK_ACTIVATION_STATUS_COMPLETE
task_id=SIMPLE_TASK.id,
status=TASK_ACTIVATION_STATUS_COMPLETE,
fetch_next_task=FetchNextTask(namespace=None),
)

def test_start_loop(self) -> None:
Expand All @@ -188,10 +199,14 @@ def test_start_loop(self) -> None:
assert mock_client.update_task.call_count == 2

mock_client.update_task.assert_any_call(
task_id=SIMPLE_TASK.id, status=TASK_ACTIVATION_STATUS_COMPLETE
task_id=SIMPLE_TASK.id,
status=TASK_ACTIVATION_STATUS_COMPLETE,
fetch_next_task=FetchNextTask(namespace=None),
)
mock_client.update_task.assert_any_call(
task_id=RETRY_TASK.id, status=TASK_ACTIVATION_STATUS_RETRY
task_id=RETRY_TASK.id,
status=TASK_ACTIVATION_STATUS_RETRY,
fetch_next_task=FetchNextTask(namespace=None),
)

def test_start_keyboard_interrupt(self) -> None:
Expand All @@ -210,5 +225,7 @@ def test_start_unknown_task(self) -> None:
result = taskworker.start()
assert result == 0, "Exit zero, all tasks complete"
mock_client.update_task.assert_any_call(
task_id=UNDEFINED_TASK.id, status=TASK_ACTIVATION_STATUS_FAILURE
task_id=UNDEFINED_TASK.id,
status=TASK_ACTIVATION_STATUS_FAILURE,
fetch_next_task=FetchNextTask(namespace=None),
)
Loading