|
49 | 49 | verify_current_serialized_func, |
50 | 50 | verify_current_serialized_func_batch, |
51 | 51 | verify_current_task, |
| 52 | + verify_current_task_is_batch, |
| 53 | + verify_current_task_is_batch_batch, |
52 | 54 | verify_current_tasks, |
53 | 55 | verify_tasktiger_instance, |
54 | 56 | ) |
@@ -1172,6 +1174,31 @@ def test_current_serialized_func_batch(self, always_eager): |
1172 | 1174 | ) |
1173 | 1175 |
|
1174 | 1176 |
|
| 1177 | +class TestCurrentTaskIsBatch(BaseTestCase): |
| 1178 | + """ |
| 1179 | + Ensure current_task_is_batch is set. |
| 1180 | + """ |
| 1181 | + |
| 1182 | + @pytest.mark.parametrize("always_eager", [False, True]) |
| 1183 | + def test_current_task_is_batch(self, always_eager): |
| 1184 | + self.tiger.config["ALWAYS_EAGER"] = always_eager |
| 1185 | + task = Task(self.tiger, verify_current_task_is_batch) |
| 1186 | + task.delay() |
| 1187 | + Worker(self.tiger).run(once=True) |
| 1188 | + assert not self.conn.exists("runtime_error") |
| 1189 | + assert self.conn.get("current_task_is_batch") == "False" |
| 1190 | + |
| 1191 | + @pytest.mark.parametrize("always_eager", [False, True]) |
| 1192 | + def test_current_task_is_batch_batch(self, always_eager): |
| 1193 | + self.tiger.config["ALWAYS_EAGER"] = always_eager |
| 1194 | + task1 = Task(self.tiger, verify_current_task_is_batch_batch) |
| 1195 | + task1.delay() |
| 1196 | + task2 = Task(self.tiger, verify_current_task_is_batch_batch) |
| 1197 | + task2.delay() |
| 1198 | + Worker(self.tiger).run(once=True) |
| 1199 | + assert self.conn.get("current_task_is_batch") == "True" |
| 1200 | + |
| 1201 | + |
1175 | 1202 | class TestTaskTigerGlobal(BaseTestCase): |
1176 | 1203 | """ |
1177 | 1204 | Ensure TaskTiger.current_instance is set. |
|
0 commit comments