11import logging
22import time
3- import typing
43import uuid
54from datetime import timedelta
65from threading import Thread
98from django .core .cache import cache
109from django .utils import timezone
1110from freezegun import freeze_time
11+ from pytest_django .fixtures import SettingsWrapper
1212from pytest_mock import MockerFixture
1313
1414from common .test_tools import AssertMetricFixture
3737
3838
3939@pytest .fixture (autouse = True )
40- def reset_cache () -> typing .Generator [None , None , None ]:
41- yield
40+ def reset_cache () -> None :
4241 cache .clear ()
4342
4443
@@ -73,38 +72,49 @@ def _sleep_task(seconds: int) -> None:
7372 return _sleep_task
7473
7574
75+ @pytest .mark .django_db (databases = ["default" , "task_processor" ])
7676@pytest .mark .task_processor_mode
77+ @pytest .mark .parametrize (
78+ "database" ,
79+ ["default" , "task_processor" ],
80+ )
7781def test_run_task_runs_task_and_creates_task_run_object_when_success (
82+ database : str ,
7883 dummy_task : TaskHandler [[str , str ]],
7984) -> None :
8085 # Given
81- task = Task .create (
82- dummy_task .task_identifier ,
83- scheduled_for = timezone .now (),
84- )
85- task .save ()
86+ task = Task .create (dummy_task .task_identifier , scheduled_for = timezone .now ())
87+ task .save (using = database )
8688
8789 # When
88- task_runs = run_tasks ("default" )
90+ task_runs = run_tasks (database )
8991
9092 # Then
9193 assert cache .get (DEFAULT_CACHE_KEY )
9294
93- assert len (task_runs ) == TaskRun .objects .filter (task = task ).count () == 1
95+ assert (
96+ len (task_runs ) == TaskRun .objects .using (database ).filter (task = task ).count () == 1
97+ )
9498 task_run = task_runs [0 ]
9599 assert task_run .result == TaskResult .SUCCESS .value
96100 assert task_run .started_at
97101 assert task_run .finished_at
98102 assert task_run .error_details is None
99103
100- task .refresh_from_db ()
104+ task .refresh_from_db (using = database )
101105 assert task .completed
102106
103107
108+ @pytest .mark .django_db (databases = ["default" , "task_processor" ])
104109@pytest .mark .task_processor_mode
110+ @pytest .mark .parametrize (
111+ "database" ,
112+ ["default" , "task_processor" ],
113+ )
105114def test_run_task_kills_task_after_timeout (
106- sleep_task : TaskHandler [[int ]],
107115 caplog : pytest .LogCaptureFixture ,
116+ database : str ,
117+ sleep_task : TaskHandler [[int ]],
108118) -> None :
109119 # Given
110120 task = Task .create (
@@ -113,21 +123,23 @@ def test_run_task_kills_task_after_timeout(
113123 args = (1 ,),
114124 timeout = timedelta (microseconds = 1 ),
115125 )
116- task .save ()
126+ task .save (using = database )
117127
118128 # When
119- task_runs = run_tasks ("default" )
129+ task_runs = run_tasks (database )
120130
121131 # Then
122- assert len (task_runs ) == TaskRun .objects .filter (task = task ).count () == 1
132+ assert (
133+ len (task_runs ) == TaskRun .objects .using (database ).filter (task = task ).count () == 1
134+ )
123135 task_run = task_runs [0 ]
124136 assert task_run .result == TaskResult .FAILURE .value
125137 assert task_run .started_at
126138 assert task_run .finished_at is None
127139 assert task_run .error_details
128140 assert "TimeoutError" in task_run .error_details
129141
130- task .refresh_from_db ()
142+ task .refresh_from_db (using = database )
131143
132144 assert task .completed is False
133145 assert task .num_failures == 1
@@ -139,12 +151,20 @@ def test_run_task_kills_task_after_timeout(
139151 )
140152
141153
142- @pytest .mark .django_db
154+ @pytest .mark .django_db ( databases = [ "default" , "task_processor" ])
143155@pytest .mark .task_processor_mode
156+ @pytest .mark .parametrize (
157+ "database" ,
158+ ["default" , "task_processor" ],
159+ )
144160def test_run_recurring_task_kills_task_after_timeout (
145161 caplog : pytest .LogCaptureFixture ,
162+ database : str ,
163+ settings : SettingsWrapper ,
146164) -> None :
147165 # Given
166+ settings .TASK_PROCESSOR_DATABASES = [database ]
167+
148168 @register_recurring_task (
149169 run_every = timedelta (seconds = 1 ), timeout = timedelta (microseconds = 1 )
150170 )
@@ -157,18 +177,22 @@ def _dummy_recurring_task() -> None:
157177 task_identifier = "test_unit_task_processor_processor._dummy_recurring_task" ,
158178 )
159179 # When
160- task_runs = run_recurring_tasks ("default" )
180+ task_runs = run_recurring_tasks (database )
161181
162182 # Then
163- assert len (task_runs ) == RecurringTaskRun .objects .filter (task = task ).count () == 1
183+ assert (
184+ len (task_runs )
185+ == RecurringTaskRun .objects .using (database ).filter (task = task ).count ()
186+ == 1
187+ )
164188 task_run = task_runs [0 ]
165189 assert task_run .result == TaskResult .FAILURE .value
166190 assert task_run .started_at
167191 assert task_run .finished_at is None
168192 assert task_run .error_details
169193 assert "TimeoutError" in task_run .error_details
170194
171- task .refresh_from_db ()
195+ task .refresh_from_db (using = database )
172196
173197 assert task .locked_at is None
174198 assert task .is_locked is False
@@ -179,6 +203,9 @@ def _dummy_recurring_task() -> None:
179203 )
180204
181205
206+ # TODO: Need to parametrize all/most tests below to run on both databases
207+
208+
182209@pytest .mark .django_db
183210@pytest .mark .task_processor_mode
184211def test_run_recurring_tasks_runs_task_and_creates_recurring_task_run_object_when_success () -> (
0 commit comments