11from concurrent .futures import Future
2- from unittest .mock import patch
2+ from unittest .mock import Mock
33
44import pytest
55from django .test .utils import override_settings
1111from sentry .conf .types .kafka_definition import Topic
1212from sentry .taskworker .registry import TaskNamespace , TaskRegistry
1313from sentry .taskworker .retry import LastAction , Retry
14+ from sentry .taskworker .router import DefaultRouter
1415from sentry .taskworker .task import Task
1516
1617
1718def test_namespace_register_task () -> None :
1819 namespace = TaskNamespace (
1920 name = "tests" ,
20- topic = Topic . TASK_WORKER ,
21+ router = DefaultRouter () ,
2122 retry = None ,
2223 )
2324
@@ -37,7 +38,7 @@ def simple_task():
3738def test_namespace_register_inherits_default_retry () -> None :
3839 namespace = TaskNamespace (
3940 name = "tests" ,
40- topic = Topic . TASK_WORKER ,
41+ router = DefaultRouter () ,
4142 retry = Retry (times = 5 , on = (RuntimeError ,)),
4243 )
4344
@@ -65,7 +66,7 @@ def retry_none_param() -> None:
6566def test_register_inherits_default_expires_processing_deadline () -> None :
6667 namespace = TaskNamespace (
6768 name = "tests" ,
68- topic = Topic . TASK_WORKER ,
69+ router = DefaultRouter () ,
6970 retry = None ,
7071 expires = 10 * 60 ,
7172 processing_deadline_duration = 5 ,
@@ -93,7 +94,7 @@ def with_expires() -> None:
9394def test_namespace_get_unknown () -> None :
9495 namespace = TaskNamespace (
9596 name = "tests" ,
96- topic = Topic . TASK_WORKER ,
97+ router = DefaultRouter () ,
9798 retry = None ,
9899 )
99100
@@ -102,10 +103,11 @@ def test_namespace_get_unknown() -> None:
102103 assert "No task registered" in str (err )
103104
104105
106+ @pytest .mark .django_db
105107def test_namespace_send_task_no_retry () -> None :
106108 namespace = TaskNamespace (
107109 name = "tests" ,
108- topic = Topic . TASK_WORKER ,
110+ router = DefaultRouter () ,
109111 retry = None ,
110112 )
111113
@@ -118,21 +120,24 @@ def simple_task() -> None:
118120 assert activation .retry_state .max_attempts == 1
119121 assert activation .retry_state .on_attempts_exceeded == ON_ATTEMPTS_EXCEEDED_DISCARD
120122
121- with patch .object (namespace , "_producer" ) as mock_producer :
122- namespace .send_task (activation )
123- assert mock_producer .produce .call_count == 1
123+ mock_producer = Mock ()
124+ namespace ._producers [Topic .TASK_WORKER ] = mock_producer
125+
126+ namespace .send_task (activation )
127+ assert mock_producer .produce .call_count == 1
124128
125- mock_call = mock_producer .produce .call_args
126- assert mock_call [0 ][0 ].name == "task-worker"
129+ mock_call = mock_producer .produce .call_args
130+ assert mock_call [0 ][0 ].name == "task-worker"
127131
128- proto_message = mock_call [0 ][1 ].value
129- assert proto_message == activation .SerializeToString ()
132+ proto_message = mock_call [0 ][1 ].value
133+ assert proto_message == activation .SerializeToString ()
130134
131135
136+ @pytest .mark .django_db
132137def test_namespace_send_task_with_retry () -> None :
133138 namespace = TaskNamespace (
134139 name = "tests" ,
135- topic = Topic . TASK_WORKER ,
140+ router = DefaultRouter () ,
136141 retry = None ,
137142 )
138143
@@ -147,19 +152,22 @@ def simple_task() -> None:
147152 assert activation .retry_state .max_attempts == 3
148153 assert activation .retry_state .on_attempts_exceeded == ON_ATTEMPTS_EXCEEDED_DEADLETTER
149154
150- with patch .object (namespace , "_producer" ) as mock_producer :
151- namespace .send_task (activation )
152- assert mock_producer .produce .call_count == 1
155+ mock_producer = Mock ()
156+ namespace ._producers [Topic .TASK_WORKER ] = mock_producer
157+
158+ namespace .send_task (activation )
159+ assert mock_producer .produce .call_count == 1
153160
154- mock_call = mock_producer .produce .call_args
155- proto_message = mock_call [0 ][1 ].value
156- assert proto_message == activation .SerializeToString ()
161+ mock_call = mock_producer .produce .call_args
162+ proto_message = mock_call [0 ][1 ].value
163+ assert proto_message == activation .SerializeToString ()
157164
158165
166+ @pytest .mark .django_db
159167def test_namespace_with_retry_send_task () -> None :
160168 namespace = TaskNamespace (
161169 name = "tests" ,
162- topic = Topic . TASK_WORKER ,
170+ router = DefaultRouter () ,
163171 retry = Retry (times = 3 ),
164172 )
165173
@@ -172,21 +180,24 @@ def simple_task() -> None:
172180 assert activation .retry_state .max_attempts == 3
173181 assert activation .retry_state .on_attempts_exceeded == ON_ATTEMPTS_EXCEEDED_DEADLETTER
174182
175- with patch .object (namespace , "_producer" ) as mock_producer :
176- namespace .send_task (activation )
177- assert mock_producer .produce .call_count == 1
183+ mock_producer = Mock ()
184+ namespace ._producers [Topic .TASK_WORKER ] = mock_producer
185+
186+ namespace .send_task (activation )
187+ assert mock_producer .produce .call_count == 1
178188
179- mock_call = mock_producer .produce .call_args
180- assert mock_call [0 ][0 ].name == "task-worker"
189+ mock_call = mock_producer .produce .call_args
190+ assert mock_call [0 ][0 ].name == "task-worker"
181191
182- proto_message = mock_call [0 ][1 ].value
183- assert proto_message == activation .SerializeToString ()
192+ proto_message = mock_call [0 ][1 ].value
193+ assert proto_message == activation .SerializeToString ()
184194
185195
196+ @pytest .mark .django_db
186197def test_namespace_with_wait_for_delivery_send_task () -> None :
187198 namespace = TaskNamespace (
188199 name = "tests" ,
189- topic = Topic . TASK_WORKER ,
200+ router = DefaultRouter () ,
190201 retry = Retry (times = 3 ),
191202 )
192203
@@ -196,18 +207,20 @@ def simple_task() -> None:
196207
197208 activation = simple_task .create_activation ()
198209
199- with patch .object (namespace , "_producer" ) as mock_producer :
200- ret_value : Future [None ] = Future ()
201- ret_value .set_result (None )
202- mock_producer .produce .return_value = ret_value
203- namespace .send_task (activation , wait_for_delivery = True )
204- assert mock_producer .produce .call_count == 1
210+ mock_producer = Mock ()
211+ namespace ._producers [Topic .TASK_WORKER ] = mock_producer
212+
213+ ret_value : Future [None ] = Future ()
214+ ret_value .set_result (None )
215+ mock_producer .produce .return_value = ret_value
216+ namespace .send_task (activation , wait_for_delivery = True )
217+ assert mock_producer .produce .call_count == 1
205218
206- mock_call = mock_producer .produce .call_args
207- assert mock_call [0 ][0 ].name == "task-worker"
219+ mock_call = mock_producer .produce .call_args
220+ assert mock_call [0 ][0 ].name == "task-worker"
208221
209- proto_message = mock_call [0 ][1 ].value
210- assert proto_message == activation .SerializeToString ()
222+ proto_message = mock_call [0 ][1 ].value
223+ assert proto_message == activation .SerializeToString ()
211224
212225
213226@pytest .mark .django_db
@@ -217,7 +230,7 @@ def test_registry_get() -> None:
217230
218231 assert isinstance (ns , TaskNamespace )
219232 assert ns .name == "tests"
220- assert ns .topic
233+ assert ns .router
221234 assert ns == registry .get ("tests" )
222235
223236 with pytest .raises (KeyError ):
@@ -284,4 +297,6 @@ def test_registry_create_namespace_route_setting() -> None:
284297 assert profiling .topic == Topic .PROFILES
285298
286299 with pytest .raises (ValueError ):
287- registry .create_namespace (name = "lol" )
300+ ns = registry .create_namespace (name = "lol" )
301+ # Should raise as the name is routed to an invalid topic
302+ ns .topic
0 commit comments