1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import asyncio
1516import struct
1617from io import BytesIO
1718from typing import Any
2021import numpy as np
2122
2223from ..utils import lazy_import
23- from .core import serialize , deserialize
24+ from .core import serialize_with_spawn , deserialize
2425
2526cupy = lazy_import ("cupy" , globals = globals ())
2627cudf = lazy_import ("cudf" , globals = globals ())
2728
2829DEFAULT_SERIALIZATION_VERSION = 1
30+ DEFAULT_SPAWN_THRESHOLD = 100
2931BUFFER_SIZES_NAME = "buf_sizes"
3032
3133
@@ -34,8 +36,10 @@ def __init__(self, obj: Any, compress=0):
3436 self ._obj = obj
3537 self ._compress = compress
3638
37- def _get_buffers (self ):
38- headers , buffers = serialize (self ._obj )
39+ async def _get_buffers (self ):
40+ headers , buffers = await serialize_with_spawn (
41+ self ._obj , spawn_threshold = DEFAULT_SPAWN_THRESHOLD
42+ )
3943
4044 def _is_cuda_buffer (buf ): # pragma: no cover
4145 if cupy is not None and cudf is not None :
@@ -78,7 +82,7 @@ def _is_cuda_buffer(buf): # pragma: no cover
7882 return out_buffers
7983
8084 async def run (self ):
81- return self ._get_buffers ()
85+ return await self ._get_buffers ()
8286
8387
8488MALFORMED_MSG = """\
@@ -123,8 +127,13 @@ async def _get_obj(self):
123127 buffer_sizes = header [0 ].pop (BUFFER_SIZES_NAME )
124128 # get buffers
125129 buffers = [await self ._readexactly (size ) for size in buffer_sizes ]
130+ # get num of objs
131+ num_objs = header [0 ].get ("_N" , 0 )
126132
127- return deserialize (header , buffers )
133+ if num_objs <= DEFAULT_SPAWN_THRESHOLD :
134+ return deserialize (header , buffers )
135+ else :
136+ return await asyncio .to_thread (deserialize , header , buffers )
128137
129138 async def run (self ):
130139 return await self ._get_obj ()
0 commit comments