Commit 4c7e433

[feat] add Serializer for rpc server
1 parent f24294b commit 4c7e433

File tree

4 files changed: +709 -0 lines changed


areal/scheduler/rpc/serializer.py

Lines changed: 274 additions & 0 deletions
@@ -0,0 +1,274 @@
import importlib.util
import io
import os
import tempfile
import zipfile
from collections.abc import Sequence
from dataclasses import asdict, is_dataclass
from enum import IntEnum
from inspect import isclass
from typing import Any, TypeAlias

import numpy as np
import torch
from msgspec import msgpack
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast


class ExtensionTypeCode(IntEnum):
    RAW_VIEW = 1
    TOKENIZER = 2


bytestr: TypeAlias = bytes | bytearray | memoryview


def tensor_data(obj: torch.Tensor) -> memoryview:
    """Extract the raw bytes from a tensor."""
    return memoryview(obj.detach().numpy().tobytes())
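
# NOTE: `.numpy()` requires a CPU tensor with a NumPy-compatible dtype;
# CUDA or bfloat16 tensors would need a `.cpu()` call or a byte-level cast
# before passing through this helper.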


class Serializer:
    """A flexible serialization/deserialization handler for RPC communication.

    This class provides a serialization protocol that supports:
    - PyTorch tensors
    - NumPy arrays
    - Hugging Face tokenizers
    - Dataclasses
    """

    magic_symbol = b"\x7e\x5c\x2e\x5e"

    def __init__(self, size_threshold: int = 1024):
        self.size_threshold = size_threshold
        self.encode_buffer: list[bytestr] | None = None
        self.decode_buffer: Sequence[bytestr] = ()
        self._encoder = msgpack.Encoder(enc_hook=self._enc_hook)

    def serialize(self, obj: Any) -> Sequence[bytestr]:
        try:
            self.encode_buffer = bufs = [b""]
            bufs[0] = self._encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            self.encode_buffer = None

    def deserialize(
        self, bufs: bytestr | Sequence[bytestr], decoded_type: Any | None = None
    ) -> Any:
        args = () if decoded_type is None else (decoded_type,)
        decoder = msgpack.Decoder(
            *args, dec_hook=self._dec_hook, ext_hook=self._ext_hook
        )
        if isinstance(bufs, bytestr):  # type: ignore
            return decoder.decode(bufs)

        self.decode_buffer = bufs
        try:
            return decoder.decode(bufs[0])
        finally:
            self.decode_buffer = ()
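
    # Note on decoding: msgspec only invokes `dec_hook` for fields whose
    # annotated type it cannot handle natively, so tensors and arrays are
    # reconstructed only when `decoded_type` (e.g. a dataclass with
    # `torch.Tensor` fields) is supplied; decoding without a type returns
    # the raw (dtype, shape, data) tuples as plain lists.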

    def _dec_hook(self, decoded_type: type, obj: Any) -> Any:
        """
        Given native types in `obj`, convert to type `decoded_type`.
        """
        if isclass(decoded_type):
            if issubclass(decoded_type, np.ndarray):
                return self._decode_ndarray(obj)
            if issubclass(decoded_type, torch.Tensor):
                return self._decode_tensor(obj)
            if decoded_type is slice:
                return slice(*obj)
        return obj

    def _ext_hook(self, code: int, data: memoryview) -> Any:
        if code == int(ExtensionTypeCode.RAW_VIEW):
            return data
        if code == int(ExtensionTypeCode.TOKENIZER):
            return self._decode_tokenizer(data)

        raise NotImplementedError(f"Extension type code {code} is not supported")

    def _enc_hook(self, obj: Any) -> Any:
        if is_dataclass(obj):
            return asdict(obj)

        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ("O", "V"):
            return self._encode_ndarray(obj)

        if isinstance(obj, PreTrainedTokenizer | PreTrainedTokenizerFast):
            return self._encode_tokenizer(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step)
            )

        raise NotImplementedError(f"Type {type(obj)} is not supported")

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        dtype, shape, data = arr
        # Zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        buffer = self.decode_buffer[data] if isinstance(data, int) else data
        return np.frombuffer(buffer, dtype=dtype).reshape(shape)

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        dtype, shape, data = arr
        # Copy from the inline representation to decouple the decoded tensor
        # from the original message buffer, and to give Torch a writable
        # buffer (it complains about read-only memoryviews).
        buffer = self.decode_buffer[data] if isinstance(data, int) else bytearray(data)
        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)
        if not buffer:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)
        # Create a uint8 array, then convert back to the proper shape & type.
        arr = torch.frombuffer(buffer, dtype=torch.uint8)
        return arr.view(torch_dtype).view(shape)

    def _decode_tokenizer(
        self, blob: memoryview
    ) -> PreTrainedTokenizer | PreTrainedTokenizerFast:
        blob, name_or_path = self._pop_magic_info(blob.tobytes())
        if blob[:4] == b"\x28\xb5\x2f\xfd":  # zstd magic header
            import zstandard as zstd

            blob = zstd.ZstdDecompressor().decompress(blob)

        from transformers import AutoTokenizer

        zip_buffer = io.BytesIO(blob)
        with tempfile.TemporaryDirectory() as tmpdir:
            with zipfile.ZipFile(zip_buffer) as zf:
                zf.extractall(tmpdir)
            tokenizer = AutoTokenizer.from_pretrained(tmpdir)
        tokenizer.name_or_path = name_or_path.decode("utf-8")
        return tokenizer

    def _encode_tokenizer(
        self, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast
    ) -> bytestr:
        zip_buffer = io.BytesIO()

        # Ensure special tokens are preserved
        assert tokenizer.special_tokens_map is not None, (
            "Tokenizer missing special tokens map"
        )

        with tempfile.TemporaryDirectory() as tmpdir:
            tokenizer.save_pretrained(tmpdir)
            total_size = sum(
                os.path.getsize(os.path.join(root, f))
                for root, _, files in os.walk(tmpdir)
                for f in files
            )

            compression = (
                zipfile.ZIP_STORED if total_size < 512 * 1024 else zipfile.ZIP_DEFLATED
            )
            with zipfile.ZipFile(
                zip_buffer, "w", compression=compression, compresslevel=6
            ) as zf:
                for root, _, files in os.walk(tmpdir):
                    for f in files:
                        zf.write(
                            os.path.join(root, f),
                            arcname=os.path.relpath(os.path.join(root, f), tmpdir),
                        )

        blob = zip_buffer.getvalue()

        if len(blob) > 20 * 1024 * 1024 and importlib.util.find_spec("zstandard"):
            import zstandard as zstd

            blob = zstd.ZstdCompressor(level=3).compress(blob)
        blob = self._append_magic_info(blob, tokenizer.name_or_path.encode("utf-8"))

        return msgpack.Ext(int(ExtensionTypeCode.TOKENIZER), blob)

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], int | memoryview]:
        assert self.encode_buffer is not None
        # If the array is non-contiguous, we need to copy it first
        arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(int(ExtensionTypeCode.RAW_VIEW), arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.encode_buffer)
            self.encode_buffer.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `encode_buffer`.
        return obj.dtype.str, obj.shape, data
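
    # A worked sketch of the two paths: np.zeros(4, dtype=np.int32) is 16
    # bytes, so it encodes inline as ("<i4", (4,), Ext(RAW_VIEW, <16 bytes>));
    # np.zeros(1024) is 8 KiB, so it encodes as ("<f8", (1024,), 1) with the
    # raw bytes shipped separately as `encode_buffer[1]`.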

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], int | memoryview]:
        assert self.encode_buffer is not None
        # Flatten the tensor to raw bytes (this copies, via NumPy's tobytes()).
        arr_data = tensor_data(obj)
        if obj.nbytes < self.size_threshold:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(int(ExtensionTypeCode.RAW_VIEW), arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.encode_buffer)
            self.encode_buffer.append(arr_data)
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

    def _append_magic_info(self, blob: bytes, info: bytes) -> bytes:
        """
        Append the magic symbol and the info, with its length, to the end
        of the blob:
        [raw blob] + [info] + [magic symbol] + [length of info] + [magic symbol]
        """
        return (
            blob
            + info
            + self.magic_symbol
            + len(info).to_bytes(4, "big")
            + self.magic_symbol
        )
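
    # E.g. (worked sketch): with info == b"gpt2", the trailer appended to the
    # blob is b"gpt2" + b"\x7e\x5c\x2e\x5e" + b"\x00\x00\x00\x04"
    # + b"\x7e\x5c\x2e\x5e", i.e. 4 info bytes plus a fixed 12-byte suffix.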

    def _pop_magic_info(self, blob: bytes) -> tuple[bytes, bytes]:
        if blob[-len(self.magic_symbol) :] != self.magic_symbol:
            return blob, b""

        info_len = int.from_bytes(
            blob[-len(self.magic_symbol) - 4 : -len(self.magic_symbol)], "big"
        )

        info_end = -len(self.magic_symbol) - 4 - len(self.magic_symbol)
        info_start = info_end - info_len
        info = blob[info_start:info_end]
        raw_blob = blob[:info_start]
        return raw_blob, info


serializer = Serializer()

serialize = serializer.serialize
deserialize = serializer.deserialize
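

# A minimal round-trip sketch (illustrative only; the `Payload` dataclass is
# hypothetical, not part of this module). Tensors and arrays are rebuilt only
# when a typed target is passed to `deserialize`, since `dec_hook` fires only
# for annotated fields.
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class Payload:
        x: torch.Tensor
        ids: np.ndarray

    payload = Payload(x=torch.randn(4, 4), ids=np.arange(2048))
    frames = serialize(payload)  # frames[0] is the msgpack envelope
    # `ids` (16 KiB) exceeds size_threshold, so it travels as its own frame.
    assert len(frames) == 2
    restored = deserialize(frames, Payload)
    assert torch.equal(restored.x, payload.x)
    assert np.array_equal(restored.ids, payload.ids)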
