Skip to content

Commit d2bbd02

Browse files
authored
expose current_task_is_batch (#362)
1 parent e1c86aa commit d2bbd02

File tree

4 files changed

+49
-1
lines changed

4 files changed

+49
-1
lines changed

tasktiger/runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ def run_eager_task(self, task: "Task") -> None:
4242
"""
4343
raise NotImplementedError("Eager tasks are not supported.")
4444

45-
def on_permanent_error(self, task: "Task", execution: Dict[str, Any] | None) -> None:
45+
def on_permanent_error(
46+
self, task: "Task", execution: Dict[str, Any] | None
47+
) -> None:
4648
"""
4749
Called if the task fails permanently.
4850

tasktiger/tasktiger.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,13 +311,19 @@ def _get_current_serialized_func(self) -> str:
311311
raise RuntimeError("Must be accessed from within a task.")
312312
return g["current_tasks"][0].serialized_func
313313

314+
def _get_current_task_is_batch(self) -> bool:
315+
if g["current_task_is_batch"] is None:
316+
raise RuntimeError("Must be accessed from within a task.")
317+
return g["current_task_is_batch"]
318+
314319
"""
315320
Properties to access the currently processing task (or tasks, in case of a
316321
batch task) from within the task. They must be invoked from within a task.
317322
"""
318323
current_task = property(_get_current_task)
319324
current_tasks = property(_get_current_tasks)
320325
current_serialized_func = property(_get_current_serialized_func)
326+
current_task_is_batch = property(_get_current_task_is_batch)
321327

322328
@classproperty
323329
def current_instance(self) -> "TaskTiger":

tests/tasks.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,19 @@ def verify_current_serialized_func_batch(tasks):
163163
conn.set("serialized_func", serialized_func)
164164

165165

166+
def verify_current_task_is_batch():
167+
with redis.Redis(host=REDIS_HOST, db=TEST_DB, decode_responses=True) as conn:
168+
is_batch = tiger.current_task_is_batch
169+
conn.set("current_task_is_batch", str(is_batch))
170+
171+
172+
@tiger.task(batch=True, queue="batch")
173+
def verify_current_task_is_batch_batch(tasks):
174+
with redis.Redis(host=REDIS_HOST, db=TEST_DB, decode_responses=True) as conn:
175+
is_batch = tiger.current_task_is_batch
176+
conn.set("current_task_is_batch", str(is_batch))
177+
178+
166179
@tiger.task()
167180
def verify_tasktiger_instance():
168181
# Not necessarily the same object, but the same configuration.

tests/test_base.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
verify_current_serialized_func,
5050
verify_current_serialized_func_batch,
5151
verify_current_task,
52+
verify_current_task_is_batch,
53+
verify_current_task_is_batch_batch,
5254
verify_current_tasks,
5355
verify_tasktiger_instance,
5456
)
@@ -1172,6 +1174,31 @@ def test_current_serialized_func_batch(self, always_eager):
11721174
)
11731175

11741176

1177+
class TestCurrentTaskIsBatch(BaseTestCase):
1178+
"""
1179+
Ensure current_task_is_batch is set.
1180+
"""
1181+
1182+
@pytest.mark.parametrize("always_eager", [False, True])
1183+
def test_current_task_is_batch(self, always_eager):
1184+
self.tiger.config["ALWAYS_EAGER"] = always_eager
1185+
task = Task(self.tiger, verify_current_task_is_batch)
1186+
task.delay()
1187+
Worker(self.tiger).run(once=True)
1188+
assert not self.conn.exists("runtime_error")
1189+
assert self.conn.get("current_task_is_batch") == "False"
1190+
1191+
@pytest.mark.parametrize("always_eager", [False, True])
1192+
def test_current_task_is_batch_batch(self, always_eager):
1193+
self.tiger.config["ALWAYS_EAGER"] = always_eager
1194+
task1 = Task(self.tiger, verify_current_task_is_batch_batch)
1195+
task1.delay()
1196+
task2 = Task(self.tiger, verify_current_task_is_batch_batch)
1197+
task2.delay()
1198+
Worker(self.tiger).run(once=True)
1199+
assert self.conn.get("current_task_is_batch") == "True"
1200+
1201+
11751202
class TestTaskTigerGlobal(BaseTestCase):
11761203
"""
11771204
Ensure TaskTiger.current_instance is set.

0 commit comments

Comments
 (0)