Skip to content

Commit 669d71f

Browse files
committed
support generator send method
1 parent 02db959 commit 669d71f

File tree

8 files changed

+168
-21
lines changed

8 files changed

+168
-21
lines changed

docs/getting-started.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,33 @@ asyncio.run(main())
361361
Do not use `engine.wait()` to wait the generator job done,
362362
because the generator job's future is done only when the generator is exhausted.
363363

364+
Generator support the `send` method, you can use this feature to pass data to the generator, it allow you communicate with another thread/process:
365+
366+
```python
367+
import asyncio
368+
from executor.engine import Engine, ProcessJob
369+
370+
def calculator():
371+
res = None
372+
while True:
373+
expr = yield res
374+
res = eval(expr)
375+
376+
377+
async def main():
378+
with Engine() as engine:
379+
job = ProcessJob(calculator)
380+
await engine.submit_async(job)
381+
await job.wait_until_status("running")
382+
g = job.result()
383+
g.send(None) # initialize the generator
384+
print(g.send("1 + 2")) # 3
385+
print(g.send("3 * 4")) # 12
386+
print(g.send("(1 + 2) * 4")) # 12
387+
388+
asyncio.run(main())
389+
```
390+
364391
## Engine
365392

366393
`executor.engine` provides a `Engine` class for managing jobs.

executor/engine/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .core import Engine, EngineSetting
22
from .job import LocalJob, ThreadJob, ProcessJob
33

4-
__version__ = '0.2.6'
4+
__version__ = '0.2.7'
55

66
__all__ = [
77
'Engine', 'EngineSetting',

executor/engine/job/dask.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dask.distributed import Client, LocalCluster
44

55
from .base import Job
6-
from .utils import GeneratorWrapper
6+
from .utils import create_generator_wrapper
77
from ..utils import PortManager
88

99

@@ -69,7 +69,7 @@ async def run_generator(self):
6969
func = functools.partial(self.func, *self.args, **self.kwargs)
7070
fut = client.submit(func)
7171
self._executor = client.get_executor(pure=False)
72-
result = GeneratorWrapper(self, fut)
72+
result = create_generator_wrapper(self, fut)
7373
return result
7474

7575
async def cancel(self):

executor/engine/job/local.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .base import Job
2-
from .utils import GeneratorWrapper
2+
from .utils import create_generator_wrapper
33

44

55
class LocalJob(Job):
@@ -10,4 +10,4 @@ async def run_function(self):
1010

1111
async def run_generator(self):
1212
"""Run job as a generator."""
13-
return GeneratorWrapper(self)
13+
return create_generator_wrapper(self)

executor/engine/job/process.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from loky.process_executor import ProcessPoolExecutor
55

66
from .base import Job
7-
from .utils import _gen_initializer, GeneratorWrapper
7+
from .utils import _gen_initializer, create_generator_wrapper
88

99

1010
class ProcessJob(Job):
@@ -56,7 +56,7 @@ async def run_generator(self):
5656
func = functools.partial(self.func, *self.args, **self.kwargs)
5757
self._executor = ProcessPoolExecutor(
5858
1, initializer=_gen_initializer, initargs=(func,))
59-
result = GeneratorWrapper(self)
59+
result = create_generator_wrapper(self)
6060
return result
6161

6262
async def cancel(self):

executor/engine/job/thread.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from concurrent.futures import ThreadPoolExecutor
44

55
from .base import Job
6-
from .utils import _gen_initializer, GeneratorWrapper
6+
from .utils import _gen_initializer, create_generator_wrapper
77

88

99
class ThreadJob(Job):
@@ -55,7 +55,7 @@ async def run_generator(self):
5555
func = functools.partial(self.func, *self.args, **self.kwargs)
5656
self._executor = ThreadPoolExecutor(
5757
1, initializer=_gen_initializer, initargs=(func,))
58-
result = GeneratorWrapper(self)
58+
result = create_generator_wrapper(self)
5959
return result
6060

6161
async def cancel(self):

executor/engine/job/utils.py

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import typing as T
22
import asyncio
3+
import inspect
34
from datetime import datetime
45
from concurrent.futures import Future
56
import threading
@@ -49,20 +50,28 @@ def _gen_initializer(gen_func, args=tuple(), kwargs={}): # pragma: no cover
4950
_thread_locals._generator = gen_func(*args, **kwargs)
5051

5152

52-
def _gen_next(fut=None): # pragma: no cover
53+
def _gen_next(send_value=None, fut=None): # pragma: no cover
5354
global _thread_locals
5455
if fut is None:
55-
return next(_thread_locals._generator)
56+
g = _thread_locals._generator
5657
else:
57-
return next(fut)
58+
g = fut
59+
if send_value is None:
60+
return next(g)
61+
else:
62+
return g.send(send_value)
5863

5964

60-
def _gen_anext(fut=None): # pragma: no cover
65+
def _gen_anext(send_value=None, fut=None): # pragma: no cover
6166
global _thread_locals
6267
if fut is None:
63-
return asyncio.run(_thread_locals._generator.__anext__())
68+
g = _thread_locals._generator
69+
else:
70+
g = fut
71+
if send_value is None:
72+
return asyncio.run(g.__anext__())
6473
else:
65-
return asyncio.run(fut.__anext__())
74+
return asyncio.run(g.asend(send_value))
6675

6776

6877
class GeneratorWrapper():
@@ -75,19 +84,28 @@ def __init__(self, job: "Job", fut: T.Optional[Future] = None):
7584
self._fut = fut
7685
self._local_res = None
7786

87+
88+
class SyncGeneratorWrapper(GeneratorWrapper):
89+
"""
90+
wrap a generator in executor pool
91+
"""
7892
def __iter__(self):
7993
return self
8094

81-
def __next__(self):
95+
def _next(self, send_value=None):
8296
try:
8397
if self._job._executor is not None:
8498
return self._job._executor.submit(
85-
_gen_next, self._fut).result()
99+
_gen_next, send_value, self._fut).result()
86100
else:
101+
# create local generator
87102
if self._local_res is None:
88103
self._local_res = self._job.func(
89104
*self._job.args, **self._job.kwargs)
90-
return next(self._local_res)
105+
if send_value is not None:
106+
return self._local_res.send(send_value)
107+
else:
108+
return next(self._local_res)
91109
except Exception as e:
92110
engine = self._job.engine
93111
if engine is None:
@@ -102,23 +120,52 @@ def __next__(self):
102120
fut.result()
103121
raise e
104122

123+
def __next__(self):
124+
return self._next()
125+
126+
def send(self, value):
127+
return self._next(value)
128+
129+
130+
class AsyncGeneratorWrapper(GeneratorWrapper):
131+
"""
132+
wrap a generator in executor pool
133+
"""
105134
def __aiter__(self):
106135
return self
107136

108-
async def __anext__(self):
137+
async def _anext(self, send_value=None):
109138
try:
110139
if self._job._executor is not None:
111-
fut = self._job._executor.submit(_gen_anext, self._fut)
140+
fut = self._job._executor.submit(
141+
_gen_anext, send_value, self._fut)
112142
res = await asyncio.wrap_future(fut)
113143
return res
114144
else:
115145
if self._local_res is None:
116146
self._local_res = self._job.func(
117147
*self._job.args, **self._job.kwargs)
118-
return await self._local_res.__anext__()
148+
if send_value is not None:
149+
return await self._local_res.asend(send_value)
150+
else:
151+
return await self._local_res.__anext__()
119152
except Exception as e:
120153
if isinstance(e, StopAsyncIteration):
121154
await self._job.on_done(self)
122155
else:
123156
await self._job.on_failed(e)
124157
raise e
158+
159+
async def __anext__(self):
160+
return await self._anext()
161+
162+
async def asend(self, value):
163+
return await self._anext(value)
164+
165+
166+
def create_generator_wrapper(
167+
job: "Job", fut: T.Optional[Future] = None) -> GeneratorWrapper:
168+
if inspect.isasyncgenfunction(job.func):
169+
return AsyncGeneratorWrapper(job, fut)
170+
else:
171+
return SyncGeneratorWrapper(job, fut)

tests/test_job.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,3 +286,76 @@ async def gen_error():
286286
async for i in job.result():
287287
assert job.status == "running"
288288
assert job.status == "failed"
289+
290+
291+
@pytest.mark.asyncio
292+
async def test_generator_send():
293+
with Engine() as engine:
294+
def gen():
295+
res = 0
296+
for _ in range(3):
297+
res += yield res
298+
299+
job = ProcessJob(gen)
300+
await engine.submit_async(job)
301+
await job.wait_until_status("running")
302+
assert job.status == "running"
303+
g = job.result()
304+
assert g.send(None) == 0
305+
assert g.send(1) == 1
306+
assert g.send(2) == 3
307+
with pytest.raises(StopIteration):
308+
g.send(3)
309+
assert job.status == "done"
310+
311+
async def gen_async():
312+
res = 0
313+
for _ in range(3):
314+
res += yield res
315+
316+
job = ProcessJob(gen_async)
317+
await engine.submit_async(job)
318+
await job.wait_until_status("running")
319+
assert job.status == "running"
320+
g = job.result()
321+
assert await g.asend(None) == 0
322+
assert await g.asend(1) == 1
323+
assert await g.asend(2) == 3
324+
with pytest.raises(StopAsyncIteration):
325+
await g.asend(3)
326+
assert job.status == "done"
327+
328+
329+
@pytest.mark.asyncio
330+
async def test_generator_send_localjob():
331+
with Engine() as engine:
332+
def gen():
333+
res = 0
334+
for _ in range(3):
335+
res += yield res
336+
337+
job = LocalJob(gen)
338+
engine.submit(job)
339+
await job.wait_until_status("running")
340+
g = job.result()
341+
assert g.send(None) == 0
342+
assert g.send(1) == 1
343+
assert g.send(2) == 3
344+
with pytest.raises(StopIteration):
345+
g.send(3)
346+
347+
# test async generator
348+
async def gen_async():
349+
res = 0
350+
for _ in range(3):
351+
res += yield res
352+
353+
job = LocalJob(gen_async)
354+
engine.submit(job)
355+
await job.wait_until_status("running")
356+
g = job.result()
357+
assert await g.asend(None) == 0
358+
assert await g.asend(1) == 1
359+
assert await g.asend(2) == 3
360+
with pytest.raises(StopAsyncIteration):
361+
await g.asend(3)

0 commit comments

Comments
 (0)