Skip to content

Commit 443c097

Browse files
committed
expose current_task_is_batch
1 parent 472c913 commit 443c097

File tree

3 files changed

+56
-0
lines changed

3 files changed

+56
-0
lines changed

tasktiger/tasktiger.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,13 +317,19 @@ def _get_current_serialized_func(self) -> str:
317317
raise RuntimeError("Must be accessed from within a task.")
318318
return g["current_tasks"][0].serialized_func
319319

320+
def _get_current_task_is_batch(self) -> bool:
321+
if g["current_task_is_batch"] is None:
322+
raise RuntimeError("Must be accessed from within a task.")
323+
return g["current_task_is_batch"]
324+
320325
"""
321326
Properties to access the currently processing task (or tasks, in case of a
322327
batch task) from within the task. They must be invoked from within a task.
323328
"""
324329
current_task = property(_get_current_task)
325330
current_tasks = property(_get_current_tasks)
326331
current_serialized_func = property(_get_current_serialized_func)
332+
current_task_is_batch = property(_get_current_task_is_batch)
327333

328334
@classproperty
329335
def current_instance(self) -> "TaskTiger":

tests/tasks.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,23 @@ def verify_current_serialized_func_batch(tasks):
183183
conn.set("serialized_func", serialized_func)
184184

185185

186+
def verify_current_task_is_batch():
187+
with redis.Redis(
188+
host=REDIS_HOST, db=TEST_DB, decode_responses=True
189+
) as conn:
190+
is_batch = tiger.current_task_is_batch
191+
conn.set("current_task_is_batch", str(is_batch))
192+
193+
194+
@tiger.task(batch=True, queue="batch")
195+
def verify_current_task_is_batch_batch(tasks):
196+
with redis.Redis(
197+
host=REDIS_HOST, db=TEST_DB, decode_responses=True
198+
) as conn:
199+
is_batch = tiger.current_task_is_batch
200+
conn.set("current_task_is_batch", str(is_batch))
201+
202+
186203
@tiger.task()
187204
def verify_tasktiger_instance():
188205
# Not necessarily the same object, but the same configuration.

tests/test_base.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
verify_current_serialized_func,
5050
verify_current_serialized_func_batch,
5151
verify_current_task,
52+
verify_current_task_is_batch,
53+
verify_current_task_is_batch_batch,
5254
verify_current_tasks,
5355
verify_tasktiger_instance,
5456
)
@@ -1200,6 +1202,37 @@ def test_current_serialized_func_batch(self, always_eager):
12001202
)
12011203

12021204

1205+
class TestCurrentTaskIsBatch(BaseTestCase):
1206+
"""
1207+
Ensure current_task_is_batch is set.
1208+
"""
1209+
1210+
@pytest.mark.parametrize("always_eager", [False, True])
1211+
def test_current_task_is_batch(self, always_eager):
1212+
self.tiger.config["ALWAYS_EAGER"] = always_eager
1213+
task = Task(self.tiger, verify_current_task_is_batch)
1214+
task.delay()
1215+
Worker(self.tiger).run(once=True)
1216+
assert not self.conn.exists("runtime_error")
1217+
assert (
1218+
self.conn.get("current_task_is_batch")
1219+
== "False"
1220+
)
1221+
1222+
@pytest.mark.parametrize("always_eager", [False, True])
1223+
def test_current_task_is_batch_batch(self, always_eager):
1224+
self.tiger.config["ALWAYS_EAGER"] = always_eager
1225+
task1 = Task(self.tiger, verify_current_task_is_batch_batch)
1226+
task1.delay()
1227+
task2 = Task(self.tiger, verify_current_task_is_batch_batch)
1228+
task2.delay()
1229+
Worker(self.tiger).run(once=True)
1230+
assert (
1231+
self.conn.get("current_task_is_batch")
1232+
== "True"
1233+
)
1234+
1235+
12031236
class TestTaskTigerGlobal(BaseTestCase):
12041237
"""
12051238
Ensure TaskTiger.current_instance is set.

0 commit comments

Comments
 (0)