Skip to content

Commit 7b365fd

Browse files
committed
fix for support generator
1 parent 264bfb8 commit 7b365fd

File tree

7 files changed

+143
-23
lines changed

7 files changed

+143
-23
lines changed

README.md

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,6 @@ from executor.engine import Engine, ProcessJob
7777
def add(a, b):
7878
return a + b
7979

80-
async def stream():
81-
for i in range(5):
82-
await asyncio.sleep(0.5)
83-
yield i
84-
8580
with Engine() as engine:
8681
# job1 and job2 will be executed in parallel
8782
job1 = ProcessJob(add, args=(1, 2))
@@ -92,12 +87,6 @@ with Engine() as engine:
9287
engine.wait_job(job3) # wait for job3 done
9388
print(job3.result()) # 10
9489

95-
# generator
96-
job4 = ProcessJob(stream)
97-
# do not do engine.wait because the generator job's future is done only when StopIteration
98-
await engine.submit_async(job4)
99-
async for x in job3.result():
100-
print(x)
10190
```
10291

10392
Async mode example:
@@ -111,6 +100,11 @@ engine = Engine()
111100
def add(a, b):
112101
return a + b
113102

103+
async def stream():
104+
for i in range(5):
105+
await asyncio.sleep(0.5)
106+
yield i
107+
114108
async def main():
115109
job1 = ProcessJob(add, args=(1, 2))
116110
job2 = ProcessJob(add, args=(job1.future, 4))
@@ -119,6 +113,13 @@ async def main():
119113
print(job1.result()) # 3
120114
print(job2.result()) # 7
121115

116+
# generator
117+
job3 = ProcessJob(stream)
118+
# do not do engine.wait because the generator job's future is done only when StopIteration
119+
await engine.submit_async(job3)
120+
async for x in job3.result():
121+
print(x)
122+
122123
asyncio.run(main())
123124
# or just `await main()` in jupyter environment
124125
```

docs/getting-started.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,48 @@ with Engine() as engine:
319319
engine.wait()
320320
```
321321

322+
### Generator support
323+
324+
`executor.engine` supports generator job, which is a special job that returns a generator.
325+
326+
```python
327+
import asyncio
328+
from executor.engine import Engine, ProcessJob
329+
330+
engine = Engine()
331+
332+
def gen():
333+
for i in range(10):
334+
yield i
335+
336+
async def async_gen():
337+
for i in range(10):
338+
await asyncio.sleep(0.5)
339+
yield i
340+
341+
async def main():
342+
job = ProcessJob(gen)
343+
await engine.submit_async(job)
344+
await job.wait_until_status("running")
345+
for i in job.result():
346+
print(i)
347+
348+
job = ProcessJob(async_gen)
349+
await engine.submit_async(job)
350+
await job.wait_until_status("running")
351+
async for i in job.result():
352+
print(i)
353+
354+
asyncio.run(main())
355+
```
356+
357+
!!! info
358+
`LocalJob`, `ThreadJob`, `ProcessJob`, `DaskJob` support generator job.
359+
360+
!!! warning
361+
Do not use `engine.wait()` to wait the generator job done,
362+
because the generator job's future is done only when the generator is exhausted.
363+
322364
## Engine
323365

324366
`executor.engine` provides a `Engine` class for managing jobs.

executor/engine/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .core import Engine, EngineSetting
22
from .job import LocalJob, ThreadJob, ProcessJob
33

4-
__version__ = '0.2.5'
4+
__version__ = '0.2.6'
55

66
__all__ = [
77
'Engine', 'EngineSetting',

executor/engine/job/local.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .base import Job
2+
from .utils import GeneratorWrapper
23

34

45
class LocalJob(Job):
@@ -9,5 +10,4 @@ async def run_function(self):
910

1011
async def run_generator(self):
1112
"""Run job as a generator."""
12-
res = self.func(*self.args, **self.kwargs)
13-
return res
13+
return GeneratorWrapper(self)

executor/engine/job/utils.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,25 @@ class GeneratorWrapper(T.Generic[_T]):
7474
def __init__(self, job: "Job", fut: T.Optional[Future] = None):
7575
self._job = job
7676
self._fut = fut
77+
self._local_res = None
7778

7879
def __iter__(self):
7980
return self
8081

8182
def __next__(self) -> _T:
8283
try:
83-
return self._job._executor.submit(_gen_next, self._fut).result()
84+
if self._job._executor is not None:
85+
return self._job._executor.submit(
86+
_gen_next, self._fut).result()
87+
else:
88+
if self._local_res is None:
89+
self._local_res = self._job.func(
90+
*self._job.args, **self._job.kwargs)
91+
return next(self._local_res)
8492
except Exception as e:
8593
engine = self._job.engine
8694
if engine is None:
87-
loop = asyncio.get_event_loop()
95+
loop = asyncio.get_event_loop() # pragma: no cover
8896
else:
8997
loop = engine.loop
9098
if isinstance(e, StopIteration):
@@ -100,9 +108,15 @@ def __aiter__(self):
100108

101109
async def __anext__(self) -> _T:
102110
try:
103-
fut = self._job._executor.submit(_gen_anext, self._fut)
104-
res = await asyncio.wrap_future(fut)
105-
return res
111+
if self._job._executor is not None:
112+
fut = self._job._executor.submit(_gen_anext, self._fut)
113+
res = await asyncio.wrap_future(fut)
114+
return res
115+
else:
116+
if self._local_res is None:
117+
self._local_res = self._job.func(
118+
*self._job.args, **self._job.kwargs)
119+
return await self._local_res.__anext__()
106120
except Exception as e:
107121
if isinstance(e, StopAsyncIteration):
108122
await self._job.on_done(self)

tests/test_dask_job.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,22 @@ async def main():
7676
client.close()
7777

7878
asyncio.run(main())
79+
80+
81+
@pytest.mark.asyncio
82+
async def test_dask_generator():
83+
with Engine() as engine:
84+
async def gen():
85+
for i in range(10):
86+
yield i
87+
88+
job = DaskJob(gen)
89+
await engine.submit_async(job)
90+
await job.wait_until_status("running")
91+
assert job.status == "running"
92+
g = job.result()
93+
i = 0
94+
async for x in g:
95+
assert x == i
96+
i += 1
97+
assert job.status == "done"

tests/test_job.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,26 +195,70 @@ def gen():
195195

196196
job = ProcessJob(gen)
197197
await engine.submit_async(job)
198-
await job.join()
198+
await job.wait_until_status("running")
199199
assert job.status == "running"
200200
g = job.result()
201201
assert list(g) == list(range(10))
202202
assert job.status == "done"
203203

204+
job = ThreadJob(gen)
205+
await engine.submit_async(job)
206+
await job.wait_until_status("running")
207+
assert job.status == "running"
208+
g = job.result()
209+
assert list(g) == list(range(10))
210+
assert job.status == "done"
211+
212+
job = LocalJob(gen)
213+
await engine.submit_async(job)
214+
await job.wait_until_status("running")
215+
assert job.status == "running"
216+
g = job.result()
217+
assert list(g) == list(range(10))
218+
assert job.status == "done"
219+
220+
221+
@pytest.mark.asyncio
222+
async def test_generator_async():
223+
with Engine() as engine:
204224
async def gen_async(n):
205225
for i in range(n):
206226
yield i
207227

208228
job = ProcessJob(gen_async, (10,))
209229
await engine.submit_async(job)
210-
await job.join()
230+
await job.wait_until_status("running")
231+
res = []
232+
async for i in job.result():
233+
assert job.status == "running"
234+
res.append(i)
235+
assert job.status == "done"
236+
assert res == list(range(10))
237+
238+
job = ThreadJob(gen_async, (10,))
239+
await engine.submit_async(job)
240+
await job.wait_until_status("running")
211241
res = []
212242
async for i in job.result():
213243
assert job.status == "running"
214244
res.append(i)
215245
assert job.status == "done"
216246
assert res == list(range(10))
217247

248+
job = LocalJob(gen_async, (10,))
249+
await engine.submit_async(job)
250+
await job.wait_until_status("running")
251+
res = []
252+
async for i in job.result():
253+
assert job.status == "running"
254+
res.append(i)
255+
assert job.status == "done"
256+
assert res == list(range(10))
257+
258+
259+
@pytest.mark.asyncio
260+
async def test_generator_error():
261+
with Engine() as engine:
218262
def gen_error():
219263
for i in range(2):
220264
print(i)
@@ -223,7 +267,7 @@ def gen_error():
223267

224268
job = ProcessJob(gen_error)
225269
await engine.submit_async(job)
226-
await job.join()
270+
await job.wait_until_status("running")
227271
with pytest.raises(ValueError):
228272
for i in job.result():
229273
assert job.status == "running"
@@ -237,7 +281,7 @@ async def gen_error():
237281

238282
job = ProcessJob(gen_error)
239283
await engine.submit_async(job)
240-
await job.join()
284+
await job.wait_until_status("running")
241285
with pytest.raises(ValueError):
242286
async for i in job.result():
243287
assert job.status == "running"

0 commit comments

Comments
 (0)