File tree Expand file tree Collapse file tree 8 files changed +61
-5
lines changed Expand file tree Collapse file tree 8 files changed +61
-5
lines changed Original file line number Diff line number Diff line change 1
1
from .core import Engine , EngineSetting
2
2
from .job import LocalJob , ThreadJob , ProcessJob
3
3
4
- __version__ = '0.2.7 '
4
+ __version__ = '0.2.8 '
5
5
6
6
__all__ = [
7
7
'Engine' , 'EngineSetting' ,
Original file line number Diff line number Diff line change 1
1
import functools
2
+ from inspect import iscoroutinefunction
2
3
3
4
from dask .distributed import Client , LocalCluster
4
5
5
6
from .base import Job
6
- from .utils import create_generator_wrapper
7
+ from .utils import create_generator_wrapper , run_async_func
7
8
from ..utils import PortManager
8
9
9
10
@@ -58,6 +59,8 @@ async def run_function(self):
58
59
"""Run job with Dask."""
59
60
client = self .engine .dask_client
60
61
func = functools .partial (self .func , * self .args , ** self .kwargs )
62
+ if iscoroutinefunction (func ):
63
+ func = functools .partial (run_async_func , func )
61
64
fut = client .submit (func )
62
65
self ._executor = fut
63
66
result = await fut
Original file line number Diff line number Diff line change
1
+ from inspect import iscoroutinefunction
2
+
1
3
from .base import Job
2
4
from .utils import create_generator_wrapper
3
5
4
6
5
7
class LocalJob (Job ):
6
8
async def run_function (self ):
7
9
"""Run job in local thread."""
8
- res = self .func (* self .args , ** self .kwargs )
10
+ if iscoroutinefunction (self .func ):
11
+ res = await self .func (* self .args , ** self .kwargs )
12
+ else :
13
+ res = self .func (* self .args , ** self .kwargs )
9
14
return res
10
15
11
16
async def run_generator (self ):
Original file line number Diff line number Diff line change 1
1
import asyncio
2
2
import functools
3
+ from inspect import iscoroutinefunction
3
4
4
5
from loky .process_executor import ProcessPoolExecutor
5
6
6
7
from .base import Job
7
- from .utils import _gen_initializer , create_generator_wrapper
8
+ from .utils import (
9
+ _gen_initializer , create_generator_wrapper , run_async_func
10
+ )
8
11
9
12
10
13
class ProcessJob (Job ):
@@ -45,6 +48,8 @@ def release_resource(self) -> bool:
45
48
async def run_function (self ):
46
49
"""Run job in process pool."""
47
50
func = functools .partial (self .func , * self .args , ** self .kwargs )
51
+ if iscoroutinefunction (func ):
52
+ func = functools .partial (run_async_func , func )
48
53
self ._executor = ProcessPoolExecutor (1 )
49
54
loop = asyncio .get_running_loop ()
50
55
fut = loop .run_in_executor (self ._executor , func )
Original file line number Diff line number Diff line change 1
1
import asyncio
2
2
import functools
3
+ from inspect import iscoroutinefunction
3
4
from concurrent .futures import ThreadPoolExecutor
4
5
5
6
from .base import Job
6
- from .utils import _gen_initializer , create_generator_wrapper
7
+ from .utils import (
8
+ _gen_initializer , create_generator_wrapper , run_async_func
9
+ )
7
10
8
11
9
12
class ThreadJob (Job ):
@@ -44,6 +47,8 @@ def release_resource(self) -> bool:
44
47
async def run_function (self ):
45
48
"""Run job in thread pool."""
46
49
func = functools .partial (self .func , * self .args , ** self .kwargs )
50
+ if iscoroutinefunction (func ):
51
+ func = functools .partial (run_async_func , func )
47
52
self ._executor = ThreadPoolExecutor (1 )
48
53
loop = asyncio .get_running_loop ()
49
54
fut = loop .run_in_executor (self ._executor , func )
Original file line number Diff line number Diff line change @@ -169,3 +169,7 @@ def create_generator_wrapper(
169
169
return AsyncGeneratorWrapper (job , fut )
170
170
else :
171
171
return SyncGeneratorWrapper (job , fut )
172
+
173
+
174
+ def run_async_func (func , * args , ** kwargs ):
175
+ return asyncio .run (func (* args , ** kwargs ))
Original file line number Diff line number Diff line change @@ -104,3 +104,24 @@ async def gen():
104
104
assert x == i
105
105
i += 1
106
106
assert job .status == "done"
107
+
108
+
109
+ @pytest .mark .asyncio
110
+ async def test_dask_async_func ():
111
+ port = PortManager .find_free_port ()
112
+ cluster = LocalCluster (
113
+ dashboard_address = f":{ port } " ,
114
+ asynchronous = True ,
115
+ processes = False ,
116
+ )
117
+ client = Client (cluster )
118
+ engine = Engine ()
119
+ engine .dask_client = client
120
+
121
+ async def async_func (x ):
122
+ return x + 1
123
+
124
+ job = DaskJob (async_func , (1 ,))
125
+ await engine .submit_async (job )
126
+ await job .wait_until_status ("done" )
127
+ assert job .result () == 2
Original file line number Diff line number Diff line change @@ -359,3 +359,16 @@ async def gen_async():
359
359
assert await g .asend (2 ) == 3
360
360
with pytest .raises (StopAsyncIteration ):
361
361
await g .asend (3 )
362
+
363
+
364
+ @pytest .mark .asyncio
365
+ async def test_async_func_job ():
366
+ with Engine () as engine :
367
+ async def async_func (x ):
368
+ return x + 1
369
+
370
+ for job_cls in [LocalJob , ThreadJob , ProcessJob ]:
371
+ job = job_cls (async_func , (1 ,))
372
+ await engine .submit_async (job )
373
+ await job .wait_until_status ("done" )
374
+ assert job .result () == 2
You can’t perform that action at this time.
0 commit comments