Skip to content

Commit a58906f

Browse files
committed
support async function
1 parent 669d71f commit a58906f

File tree

8 files changed

+61
-5
lines changed

8 files changed

+61
-5
lines changed

executor/engine/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .core import Engine, EngineSetting
22
from .job import LocalJob, ThreadJob, ProcessJob
33

4-
__version__ = '0.2.7'
4+
__version__ = '0.2.8'
55

66
__all__ = [
77
'Engine', 'EngineSetting',

executor/engine/job/dask.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import functools
2+
from inspect import iscoroutinefunction
23

34
from dask.distributed import Client, LocalCluster
45

56
from .base import Job
6-
from .utils import create_generator_wrapper
7+
from .utils import create_generator_wrapper, run_async_func
78
from ..utils import PortManager
89

910

@@ -58,6 +59,8 @@ async def run_function(self):
5859
"""Run job with Dask."""
5960
client = self.engine.dask_client
6061
func = functools.partial(self.func, *self.args, **self.kwargs)
62+
if iscoroutinefunction(func):
63+
func = functools.partial(run_async_func, func)
6164
fut = client.submit(func)
6265
self._executor = fut
6366
result = await fut

executor/engine/job/local.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1+
from inspect import iscoroutinefunction
2+
13
from .base import Job
24
from .utils import create_generator_wrapper
35

46

57
class LocalJob(Job):
68
async def run_function(self):
79
"""Run job in local thread."""
8-
res = self.func(*self.args, **self.kwargs)
10+
if iscoroutinefunction(self.func):
11+
res = await self.func(*self.args, **self.kwargs)
12+
else:
13+
res = self.func(*self.args, **self.kwargs)
914
return res
1015

1116
async def run_generator(self):

executor/engine/job/process.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import asyncio
22
import functools
3+
from inspect import iscoroutinefunction
34

45
from loky.process_executor import ProcessPoolExecutor
56

67
from .base import Job
7-
from .utils import _gen_initializer, create_generator_wrapper
8+
from .utils import (
9+
_gen_initializer, create_generator_wrapper, run_async_func
10+
)
811

912

1013
class ProcessJob(Job):
@@ -45,6 +48,8 @@ def release_resource(self) -> bool:
4548
async def run_function(self):
4649
"""Run job in process pool."""
4750
func = functools.partial(self.func, *self.args, **self.kwargs)
51+
if iscoroutinefunction(func):
52+
func = functools.partial(run_async_func, func)
4853
self._executor = ProcessPoolExecutor(1)
4954
loop = asyncio.get_running_loop()
5055
fut = loop.run_in_executor(self._executor, func)

executor/engine/job/thread.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import asyncio
22
import functools
3+
from inspect import iscoroutinefunction
34
from concurrent.futures import ThreadPoolExecutor
45

56
from .base import Job
6-
from .utils import _gen_initializer, create_generator_wrapper
7+
from .utils import (
8+
_gen_initializer, create_generator_wrapper, run_async_func
9+
)
710

811

912
class ThreadJob(Job):
@@ -44,6 +47,8 @@ def release_resource(self) -> bool:
4447
async def run_function(self):
4548
"""Run job in thread pool."""
4649
func = functools.partial(self.func, *self.args, **self.kwargs)
50+
if iscoroutinefunction(func):
51+
func = functools.partial(run_async_func, func)
4752
self._executor = ThreadPoolExecutor(1)
4853
loop = asyncio.get_running_loop()
4954
fut = loop.run_in_executor(self._executor, func)

executor/engine/job/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,7 @@ def create_generator_wrapper(
169169
return AsyncGeneratorWrapper(job, fut)
170170
else:
171171
return SyncGeneratorWrapper(job, fut)
172+
173+
174+
def run_async_func(func, *args, **kwargs):
175+
return asyncio.run(func(*args, **kwargs))

tests/test_dask_job.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,24 @@ async def gen():
104104
assert x == i
105105
i += 1
106106
assert job.status == "done"
107+
108+
109+
@pytest.mark.asyncio
110+
async def test_dask_async_func():
111+
port = PortManager.find_free_port()
112+
cluster = LocalCluster(
113+
dashboard_address=f":{port}",
114+
asynchronous=True,
115+
processes=False,
116+
)
117+
client = Client(cluster)
118+
engine = Engine()
119+
engine.dask_client = client
120+
121+
async def async_func(x):
122+
return x + 1
123+
124+
job = DaskJob(async_func, (1,))
125+
await engine.submit_async(job)
126+
await job.wait_until_status("done")
127+
assert job.result() == 2

tests/test_job.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,3 +359,16 @@ async def gen_async():
359359
assert await g.asend(2) == 3
360360
with pytest.raises(StopAsyncIteration):
361361
await g.asend(3)
362+
363+
364+
@pytest.mark.asyncio
365+
async def test_async_func_job():
366+
with Engine() as engine:
367+
async def async_func(x):
368+
return x + 1
369+
370+
for job_cls in [LocalJob, ThreadJob, ProcessJob]:
371+
job = job_cls(async_func, (1,))
372+
await engine.submit_async(job)
373+
await job.wait_until_status("done")
374+
assert job.result() == 2

0 commit comments

Comments
 (0)