Skip to content

Commit a73970b

Browse files
authored
Allow spawning serialization to threads for large objects (#2944)
1 parent 1949685 commit a73970b

File tree

10 files changed

+254
-88
lines changed

10 files changed

+254
-88
lines changed

mars/core/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from functools import wraps
1616
from typing import Dict
1717

18-
from ..serialization.core import Placeholder, short_id
18+
from ..serialization.core import Placeholder, fast_id
1919
from ..serialization.serializables import Serializable, StringField
2020
from ..serialization.serializables.core import SerializableSerializer
2121
from ..utils import tokenize
@@ -123,7 +123,7 @@ def buffered_base(func):
123123
def wrapped(self, obj: Base, context: Dict):
124124
obj_id = (obj.key, obj.id)
125125
if obj_id in context:
126-
return Placeholder(short_id(context[obj_id]))
126+
return Placeholder(fast_id(context[obj_id]))
127127
else:
128128
context[obj_id] = obj
129129
return func(self, obj, context)

mars/oscar/backends/message.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ class DeserializeMessageFailed(RuntimeError):
515515

516516

517517
cdef class MessageSerializer(Serializer):
518-
serializer_id = 56951
518+
serializer_id = 32105
519519

520520
cpdef serial(self, object obj, dict context):
521521
cdef _MessageBase msg = <_MessageBase>obj

mars/oscar/batch.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,9 @@ def _gen_args_kwargs_list(delays):
143143
async def _async_batch(self, *delays):
144144
# when there is only one call in batch, calling one-pass method
145145
# will be more efficient
146-
if len(delays) == 1:
146+
if len(delays) == 0:
147+
return []
148+
elif len(delays) == 1:
147149
d = delays[0]
148150
return [await self._async_call(*d.args, **d.kwargs)]
149151
elif self.batch_func:
@@ -162,7 +164,9 @@ async def _async_batch(self, *delays):
162164
return await asyncio.gather(*tasks)
163165

164166
def _sync_batch(self, *delays):
165-
if self.batch_func:
167+
if len(delays) == 0:
168+
return []
169+
elif self.batch_func:
166170
args_list, kwargs_list = self._gen_args_kwargs_list(delays)
167171
return self.batch_func(args_list, kwargs_list)
168172
else:

mars/oscar/tests/test_batch.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ def method(self, args_list, kwargs_list):
146146
if use_async:
147147
assert asyncio.iscoroutinefunction(TestClass.method)
148148

149+
test_inst = TestClass()
150+
ret = test_inst.method.batch()
151+
ret = await ret if use_async else ret
152+
assert ret == []
153+
149154
test_inst = TestClass()
150155
ret = test_inst.method.batch(test_inst.method.delay(12))
151156
ret = await ret if use_async else ret

mars/serialization/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from .aio import AioSerializer, AioDeserializer
16-
from .core import serialize, deserialize, Serializer
16+
from .core import serialize, serialize_with_spawn, deserialize, Serializer
1717

1818
from . import arrow, cuda, numpy, scipy, mars_objects, ray, exception
1919

mars/serialization/aio.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import asyncio
1516
import struct
1617
from io import BytesIO
1718
from typing import Any
@@ -20,12 +21,13 @@
2021
import numpy as np
2122

2223
from ..utils import lazy_import
23-
from .core import serialize, deserialize
24+
from .core import serialize_with_spawn, deserialize
2425

2526
cupy = lazy_import("cupy", globals=globals())
2627
cudf = lazy_import("cudf", globals=globals())
2728

2829
DEFAULT_SERIALIZATION_VERSION = 1
30+
DEFAULT_SPAWN_THRESHOLD = 100
2931
BUFFER_SIZES_NAME = "buf_sizes"
3032

3133

@@ -34,8 +36,10 @@ def __init__(self, obj: Any, compress=0):
3436
self._obj = obj
3537
self._compress = compress
3638

37-
def _get_buffers(self):
38-
headers, buffers = serialize(self._obj)
39+
async def _get_buffers(self):
40+
headers, buffers = await serialize_with_spawn(
41+
self._obj, spawn_threshold=DEFAULT_SPAWN_THRESHOLD
42+
)
3943

4044
def _is_cuda_buffer(buf): # pragma: no cover
4145
if cupy is not None and cudf is not None:
@@ -78,7 +82,7 @@ def _is_cuda_buffer(buf): # pragma: no cover
7882
return out_buffers
7983

8084
async def run(self):
81-
return self._get_buffers()
85+
return await self._get_buffers()
8286

8387

8488
MALFORMED_MSG = """\
@@ -123,8 +127,13 @@ async def _get_obj(self):
123127
buffer_sizes = header[0].pop(BUFFER_SIZES_NAME)
124128
# get buffers
125129
buffers = [await self._readexactly(size) for size in buffer_sizes]
130+
# get num of objs
131+
num_objs = header[0].get("_N", 0)
126132

127-
return deserialize(header, buffers)
133+
if num_objs <= DEFAULT_SPAWN_THRESHOLD:
134+
return deserialize(header, buffers)
135+
else:
136+
return await asyncio.to_thread(deserialize, header, buffers)
128137

129138
async def run(self):
130139
return await self._get_obj()

mars/serialization/core.pyi

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from concurrent.futures import Executor
1516
from typing import Any, Callable, Dict, List, Tuple
1617

1718
def buffered(func: Callable) -> Callable: ...
18-
def short_id(obj: Any) -> int: ...
19+
def fast_id(obj: Any) -> int: ...
1920

2021
class Serializer:
2122
serializer_id: int
@@ -42,4 +43,10 @@ class Placeholder:
4243
def __eq__(self, other): ...
4344

4445
def serialize(obj: Any, context: Dict = None): ...
46+
async def serialize_with_spawn(
47+
obj: Any,
48+
context: Dict = None,
49+
spawn_threshold: int = 100,
50+
executor: Executor = None,
51+
): ...
4552
def deserialize(headers: List, buffers: List, context: Dict = None): ...

0 commit comments

Comments
 (0)