Skip to content

Commit 264bfb8

Browse files
authored
Merge pull request #6 from liunux4odoo/feat
support generator for LocalJob & ThreadJob & DaskJob
2 parents 389af29 + d491db0 commit 264bfb8

File tree

5 files changed

+65
-16
lines changed

5 files changed

+65
-16
lines changed

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ from executor.engine import Engine, ProcessJob
7777
def add(a, b):
7878
return a + b
7979

80+
async def stream():
81+
for i in range(5):
82+
await asyncio.sleep(0.5)
83+
yield i
84+
8085
with Engine() as engine:
8186
# job1 and job2 will be executed in parallel
8287
job1 = ProcessJob(add, args=(1, 2))
@@ -86,6 +91,13 @@ with Engine() as engine:
8691
engine.submit(job1, job2, job3)
8792
engine.wait_job(job3) # wait for job3 done
8893
print(job3.result()) # 10
94+
95+
# generator
96+
job4 = ProcessJob(stream)
97+
# do not do engine.wait because the generator job's future is done only when StopIteration
98+
await engine.submit_async(job4)
99+
async for x in job3.result():
100+
print(x)
89101
```
90102

91103
Async mode example:

executor/engine/job/dask.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from dask.distributed import Client, LocalCluster
44

55
from .base import Job
6+
from .utils import GeneratorWrapper
67
from ..utils import PortManager
78

89

@@ -56,12 +57,21 @@ def release_resource(self) -> bool:
5657
async def run_function(self):
5758
"""Run job with Dask."""
5859
client = self.engine.dask_client
59-
func = functools.partial(self.func, **self.kwargs)
60-
fut = client.submit(func, *self.args)
60+
func = functools.partial(self.func, *self.args, **self.kwargs)
61+
fut = client.submit(func)
6162
self._executor = fut
6263
result = await fut
6364
return result
6465

66+
async def run_generator(self):
67+
"""Run job as a generator."""
68+
client = self.engine.dask_client
69+
func = functools.partial(self.func, *self.args, **self.kwargs)
70+
fut = client.submit(func)
71+
self._executor = client.get_executor(pure=False)
72+
result = GeneratorWrapper(self, fut)
73+
return result
74+
6575
async def cancel(self):
6676
"""Cancel job."""""
6777
if self.status == "running":

executor/engine/job/local.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,8 @@ async def run_function(self):
66
"""Run job in local thread."""
77
res = self.func(*self.args, **self.kwargs)
88
return res
9+
10+
async def run_generator(self):
11+
"""Run job as a generator."""
12+
res = self.func(*self.args, **self.kwargs)
13+
return res

executor/engine/job/thread.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from concurrent.futures import ThreadPoolExecutor
44

55
from .base import Job
6+
from .utils import _gen_initializer, GeneratorWrapper
67

78

89
class ThreadJob(Job):
@@ -42,13 +43,21 @@ def release_resource(self) -> bool:
4243

4344
async def run_function(self):
4445
"""Run job in thread pool."""
46+
func = functools.partial(self.func, *self.args, **self.kwargs)
4547
self._executor = ThreadPoolExecutor(1)
4648
loop = asyncio.get_running_loop()
47-
func = functools.partial(self.func, **self.kwargs)
48-
fut = loop.run_in_executor(self._executor, func, *self.args)
49+
fut = loop.run_in_executor(self._executor, func)
4950
result = await fut
5051
return result
5152

53+
async def run_generator(self):
54+
"""Run job as a generator."""
55+
func = functools.partial(self.func, *self.args, **self.kwargs)
56+
self._executor = ThreadPoolExecutor(
57+
1, initializer=_gen_initializer, initargs=(func,))
58+
result = GeneratorWrapper(self)
59+
return result
60+
5261
async def cancel(self):
5362
"""Cancel job."""
5463
if self.status == "running":

executor/engine/job/utils.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import typing as T
22
import asyncio
33
from datetime import datetime
4+
from concurrent.futures import Future
5+
import threading
46

57
from ..utils import CheckAttrRange, ExecutorError
68

79

810
if T.TYPE_CHECKING:
911
from .base import Job
1012

11-
1213
JobStatusType = T.Literal['pending', 'running', 'failed', 'done', 'cancelled']
1314
valid_job_statuses: T.List[JobStatusType] = [
1415
'pending', 'running', 'failed', 'done', 'cancelled']
@@ -38,36 +39,48 @@ def __init__(self, job: "Job", valid_status: T.List[JobStatusType]):
3839

3940

4041
_T = T.TypeVar("_T")
42+
_thread_locals = threading.local()
4143

4244

4345
def _gen_initializer(gen_func, args=tuple(), kwargs={}): # pragma: no cover
44-
global _generator
45-
_generator = gen_func(*args, **kwargs)
46+
global _thread_locals
47+
if "_thread_locals" not in globals():
48+
# avoid conflict for ThreadJob
49+
_thread_locals = threading.local()
50+
_thread_locals._generator = gen_func(*args, **kwargs)
4651

4752

48-
def _gen_next(): # pragma: no cover
49-
global _generator
50-
return next(_generator)
53+
def _gen_next(fut=None): # pragma: no cover
54+
global _thread_locals
55+
if fut is None:
56+
return next(_thread_locals._generator)
57+
else:
58+
return next(fut)
5159

5260

53-
def _gen_anext(): # pragma: no cover
54-
global _generator
55-
return asyncio.run(_generator.__anext__())
61+
def _gen_anext(fut=None): # pragma: no cover
62+
global _thread_locals
63+
if fut is None:
64+
return asyncio.run(_thread_locals._generator.__anext__())
65+
else:
66+
return asyncio.run(fut.__anext__())
5667

5768

5869
class GeneratorWrapper(T.Generic[_T]):
5970
"""
6071
wrap a generator in executor pool
6172
"""
62-
def __init__(self, job: "Job"):
73+
74+
def __init__(self, job: "Job", fut: T.Optional[Future] = None):
6375
self._job = job
76+
self._fut = fut
6477

6578
def __iter__(self):
6679
return self
6780

6881
def __next__(self) -> _T:
6982
try:
70-
return self._job._executor.submit(_gen_next).result()
83+
return self._job._executor.submit(_gen_next, self._fut).result()
7184
except Exception as e:
7285
engine = self._job.engine
7386
if engine is None:
@@ -87,7 +100,7 @@ def __aiter__(self):
87100

88101
async def __anext__(self) -> _T:
89102
try:
90-
fut = self._job._executor.submit(_gen_anext)
103+
fut = self._job._executor.submit(_gen_anext, self._fut)
91104
res = await asyncio.wrap_future(fut)
92105
return res
93106
except Exception as e:

0 commit comments

Comments
 (0)