Skip to content

Commit 1f6c3d4

Browse files
authored
Fix NotImplementedError for mo.batch when single call not implemented (#2635)
1 parent 6afd7ed commit 1f6c3d4

File tree

5 files changed

+33
-13
lines changed

5 files changed

+33
-13
lines changed

.github/workflows/core-ci.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ jobs:
5050
5151
source ./ci/rewrite-cov-config.sh
5252
53-
pip install git+https://github.com/mars-project/pytest-asyncio.git
5453
pip install numpy scipy cython oss2
5554
pip install -e ".[dev,extra]"
5655
pip install virtualenv flaky

.github/workflows/os-compat-ci.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ jobs:
3737
3838
source ./ci/rewrite-cov-config.sh
3939
40-
pip install git+https://github.com/mars-project/pytest-asyncio.git
4140
pip install numpy scipy cython oss2
4241
pip install -e ".[dev,extra]"
4342
pip install virtualenv flaky

.github/workflows/platform-ci.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ jobs:
5151
5252
source ./ci/rewrite-cov-config.sh
5353
54-
pip install git+https://github.com/mars-project/pytest-asyncio.git
5554
pip install numpy scipy cython
5655
5756
pip install -e ".[dev,extra]"

mars/oscar/batch.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ class _ExtensibleCallable:
8282
func: Callable
8383
batch_func: Optional[Callable]
8484
is_async: bool
85+
has_single_func: bool
8586

8687
def __call__(self, *args, **kwargs):
8788
if self.is_async:
@@ -91,20 +92,26 @@ def __call__(self, *args, **kwargs):
9192

9293
async def _async_call(self, *args, **kwargs):
9394
try:
94-
return await self.func(*args, **kwargs)
95+
if self.has_single_func:
96+
return await self.func(*args, **kwargs)
9597
except NotImplementedError:
96-
if self.batch_func:
97-
ret = await self.batch_func([args], [kwargs])
98-
return None if ret is None else ret[0]
99-
raise
98+
self.has_single_func = False
99+
100+
if self.batch_func is not None:
101+
ret = await self.batch_func([args], [kwargs])
102+
return None if ret is None else ret[0]
103+
raise NotImplementedError
100104

101105
def _sync_call(self, *args, **kwargs):
102106
try:
103-
return self.func(*args, **kwargs)
107+
if self.has_single_func:
108+
return self.func(*args, **kwargs)
104109
except NotImplementedError:
105-
if self.batch_func:
106-
return self.batch_func([args], [kwargs])[0]
107-
raise
110+
self.has_single_func = False
111+
112+
if self.batch_func is not None:
113+
return self.batch_func([args], [kwargs])[0]
114+
raise NotImplementedError
108115

109116

110117
class _ExtensibleWrapper(_ExtensibleCallable):
@@ -119,6 +126,7 @@ def __init__(
119126
self.batch_func = batch_func
120127
self.bind_func = bind_func
121128
self.is_async = is_async
129+
self.has_single_func = True
122130

123131
@staticmethod
124132
def delay(*args, **kwargs):
@@ -138,7 +146,7 @@ async def _async_batch(self, *delays):
138146
# will be more efficient
139147
if len(delays) == 1:
140148
d = delays[0]
141-
return [await self.func(*d.args, **d.kwargs)]
149+
return [await self._async_call(*d.args, **d.kwargs)]
142150
elif self.batch_func:
143151
args_list, kwargs_list = self._gen_args_kwargs_list(delays)
144152
return await self.batch_func(args_list, kwargs_list)
@@ -184,6 +192,7 @@ def __init__(self, func: Callable):
184192
self.batch_func = None
185193
self.bind_func = build_args_binder(func, remove_self=True)
186194
self.is_async = asyncio.iscoroutinefunction(self.func)
195+
self.has_single_func = True
187196

188197
def batch(self, func: Callable):
189198
self.batch_func = func

mars/oscar/tests/test_batch.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ def __init__(self):
126126
self.arg_list = []
127127
self.kwarg_list = []
128128

129+
@extensible
130+
@_wrap_async(use_async)
131+
def not_implemented_method(self, *args, **kw):
132+
raise NotImplementedError
133+
129134
@extensible
130135
@_wrap_async(use_async)
131136
def method(self, *args, **kwargs):
@@ -141,6 +146,11 @@ def method(self, args_list, kwargs_list):
141146
if use_async:
142147
assert asyncio.iscoroutinefunction(TestClass.method)
143148

149+
test_inst = TestClass()
150+
ret = test_inst.method.batch(test_inst.method.delay(12))
151+
ret = await ret if use_async else ret
152+
assert ret == [1]
153+
144154
test_inst = TestClass()
145155
ret = test_inst.method.batch(test_inst.method.delay(12), test_inst.method.delay(10))
146156
ret = await ret if use_async else ret
@@ -149,6 +159,10 @@ def method(self, args_list, kwargs_list):
149159
assert test_inst.kwarg_list == [{}, {}]
150160

151161
test_inst = TestClass()
162+
for _ in range(2):
163+
with pytest.raises(NotImplementedError):
164+
ret = test_inst.not_implemented_method()
165+
await ret if use_async else ret
152166
ret = test_inst.method(12, kwarg=34)
153167
ret = await ret if use_async else ret
154168
assert ret == 1

0 commit comments

Comments
 (0)