|
49 | 49 | verify_current_serialized_func, |
50 | 50 | verify_current_serialized_func_batch, |
51 | 51 | verify_current_task, |
| 52 | + verify_current_task_is_batch, |
| 53 | + verify_current_task_is_batch_batch, |
52 | 54 | verify_current_tasks, |
53 | 55 | verify_tasktiger_instance, |
54 | 56 | ) |
@@ -1200,6 +1202,37 @@ def test_current_serialized_func_batch(self, always_eager): |
1200 | 1202 | ) |
1201 | 1203 |
|
1202 | 1204 |
|
| 1205 | +class TestCurrentTaskIsBatch(BaseTestCase): |
| 1206 | + """ |
| 1207 | + Ensure current_task_is_batch is set. |
| 1208 | + """ |
| 1209 | + |
| 1210 | + @pytest.mark.parametrize("always_eager", [False, True]) |
| 1211 | + def test_current_task_is_batch(self, always_eager): |
| 1212 | + self.tiger.config["ALWAYS_EAGER"] = always_eager |
| 1213 | + task = Task(self.tiger, verify_current_task_is_batch) |
| 1214 | + task.delay() |
| 1215 | + Worker(self.tiger).run(once=True) |
| 1216 | + assert not self.conn.exists("runtime_error") |
| 1217 | + assert ( |
| 1218 | + self.conn.get("current_task_is_batch") |
| 1219 | + == "False" |
| 1220 | + ) |
| 1221 | + |
| 1222 | + @pytest.mark.parametrize("always_eager", [False, True]) |
| 1223 | + def test_current_task_is_batch_batch(self, always_eager): |
| 1224 | + self.tiger.config["ALWAYS_EAGER"] = always_eager |
| 1225 | + task1 = Task(self.tiger, verify_current_task_is_batch_batch) |
| 1226 | + task1.delay() |
| 1227 | + task2 = Task(self.tiger, verify_current_task_is_batch_batch) |
| 1228 | + task2.delay() |
| 1229 | + Worker(self.tiger).run(once=True) |
| 1230 | + assert ( |
| 1231 | + self.conn.get("current_task_is_batch") |
| 1232 | + == "True" |
| 1233 | + ) |
| 1234 | + |
| 1235 | + |
1203 | 1236 | class TestTaskTigerGlobal(BaseTestCase): |
1204 | 1237 | """ |
1205 | 1238 | Ensure TaskTiger.current_instance is set. |
|
0 commit comments