Skip to content

Commit 7f22084

Browse files
committed
feat(pb): support encode_to_vec for message and that with a length-delimiter
1 parent 21a121a commit 7f22084

File tree

4 files changed

+336
-3
lines changed

4 files changed

+336
-3
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
.vscode
22
.idea
3-
/target
3+
**/target
44

55
*/src/test/gen/*
66

pilota/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,7 @@ harness = false
6060
[[bench]]
6161
name = "ttype"
6262
harness = false
63+
64+
[[bench]]
65+
name = "pb"
66+
harness = false

pilota/benches/pb.rs

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
#![allow(clippy::redundant_clone)]
2+
3+
use std::hint::black_box;
4+
5+
use bytes::Bytes;
6+
use criterion::{BenchmarkId, Throughput, criterion_group, criterion_main};
7+
use faststr::FastStr;
8+
use linkedbytes::LinkedBytes;
9+
use pilota::pb::{Message, encoding::EncodeLengthContext};
10+
use rand::{Rng, distr::Alphanumeric};
11+
12+
include!("../../pilota-build/test_data/protobuf/normal.rs");
13+
14+
const KB: usize = 1024;
15+
const SIZES: &[usize] = &[16, 64, 128, 512, 2 * KB, 128 * KB, 10 * 128 * KB];
16+
17+
/// Returns how many bytes a base-128 varint needs to represent `n`.
///
/// Each byte carries 7 bits of the value, and at least one byte is always
/// required (so `0` yields `1`).
#[inline]
fn varint_len(mut n: usize) -> usize {
    let mut bytes = 1;
    loop {
        // Everything that remains fits in the current 7-bit group.
        if n < 0x80 {
            return bytes;
        }
        n >>= 7;
        bytes += 1;
    }
}
26+
27+
#[inline]
28+
fn payload_len_for_total(total_size: usize) -> usize {
29+
// total = tag_len(1) + varint(payload_len) + payload_len
30+
// Try varint length from 1..=5 to find consistent payload_len.
31+
for vlen in 1..=5 {
32+
if total_size < 1 + vlen {
33+
continue;
34+
}
35+
let payload_len = total_size - 1 - vlen;
36+
if varint_len(payload_len) == vlen {
37+
return payload_len;
38+
}
39+
}
40+
panic!("cannot construct payload for total_size={}", total_size);
41+
}
42+
43+
#[inline]
44+
fn make_payload(total_size: usize) -> Vec<u8> {
45+
let len = payload_len_for_total(total_size);
46+
let mut v = Vec::with_capacity(len);
47+
// Deterministic but non-trivial content
48+
for i in 0..len {
49+
v.push(((i as u32).wrapping_mul(1315423911) as u8) ^ 0x5A);
50+
}
51+
v
52+
}
53+
54+
fn pb_bench(c: &mut criterion::Criterion) {
55+
let mut group = c.benchmark_group("PB Message Vec<u8> (BytesValue)");
56+
57+
for &total in SIZES.iter() {
58+
let msg: Vec<u8> = make_payload(total);
59+
60+
// Pre-compute capacities
61+
let mut ctx0 = EncodeLengthContext::default();
62+
let required_no_ld = msg.encoded_len(&mut ctx0);
63+
let mut ctx1 = EncodeLengthContext::default();
64+
let (_len_only, required_ld_total) = msg.encoded_len_length_delimited(&mut ctx1);
65+
66+
// Pre-encode for decode benches
67+
let mut ctx_enc = EncodeLengthContext::default();
68+
let enc = msg.encode_to_vec(&mut ctx_enc);
69+
let bytes_normal = Bytes::from(enc);
70+
71+
let mut ctx_enc_ld = EncodeLengthContext::default();
72+
let enc_ld = msg.encode_length_delimited_to_vec(&mut ctx_enc_ld);
73+
let bytes_ld = Bytes::from(enc_ld);
74+
75+
group.throughput(Throughput::Bytes(total as u64));
76+
77+
group.bench_function(BenchmarkId::new("encode", total), |b| {
78+
b.iter(|| {
79+
let mut buf = LinkedBytes::with_capacity(required_no_ld);
80+
msg.encode(&mut buf).unwrap();
81+
black_box(buf);
82+
})
83+
});
84+
85+
group.bench_function(BenchmarkId::new("encode_to_vec", total), |b| {
86+
b.iter(|| {
87+
let mut ctx = EncodeLengthContext::default();
88+
let v = msg.encode_to_vec(&mut ctx);
89+
black_box(v);
90+
})
91+
});
92+
93+
group.bench_function(BenchmarkId::new("encode_length_delimited", total), |b| {
94+
b.iter(|| {
95+
let mut ctx = EncodeLengthContext::default();
96+
let mut buf = LinkedBytes::with_capacity(required_ld_total);
97+
msg.encode_length_delimited(&mut ctx, &mut buf).unwrap();
98+
black_box(buf);
99+
})
100+
});
101+
102+
group.bench_function(
103+
BenchmarkId::new("encode_length_delimited_to_vec", total),
104+
|b| {
105+
b.iter(|| {
106+
let mut ctx = EncodeLengthContext::default();
107+
let v = msg.encode_length_delimited_to_vec(&mut ctx);
108+
black_box(v);
109+
})
110+
},
111+
);
112+
113+
group.bench_function(BenchmarkId::new("decode", total), |b| {
114+
b.iter(|| {
115+
let decoded = <Vec<u8> as Message>::decode(bytes_normal.clone()).unwrap();
116+
black_box(decoded);
117+
})
118+
});
119+
120+
group.bench_function(BenchmarkId::new("decode_length_delimited", total), |b| {
121+
b.iter(|| {
122+
let decoded =
123+
<Vec<u8> as Message>::decode_length_delimited(bytes_ld.clone()).unwrap();
124+
black_box(decoded);
125+
})
126+
});
127+
}
128+
129+
group.finish();
130+
}
131+
132+
#[inline]
133+
fn generate_random_string_pb(size: usize) -> FastStr {
134+
if size == 0 {
135+
return FastStr::empty();
136+
}
137+
rand::rng()
138+
.sample_iter(&Alphanumeric)
139+
.take(size)
140+
.map(char::from)
141+
.collect::<String>()
142+
.into()
143+
}
144+
145+
#[inline]
146+
fn prepare_obj_req_pb(size: usize) -> normal::ObjReq {
147+
let sub_msg_1 = normal::SubMessage {
148+
value: Some(generate_random_string_pb(size / 2)),
149+
};
150+
let sub_msg_2 = normal::SubMessage {
151+
value: Some(generate_random_string_pb(size / 2)),
152+
};
153+
154+
let sub_msg_list = vec![sub_msg_1.clone(), sub_msg_2.clone()];
155+
156+
let msg_key = normal::Message {
157+
uid: "".into(),
158+
value: Some(generate_random_string_pb(size / 2)),
159+
sub_messages: vec![sub_msg_1.clone()],
160+
};
161+
162+
let msg_val = normal::SubMessage {
163+
value: Some(generate_random_string_pb(size)),
164+
};
165+
166+
let msg_map_entry = normal::obj_req::MsgMapEntry {
167+
key: Some(msg_key),
168+
value: Some(msg_val),
169+
};
170+
171+
let msg_for_set_and_field = normal::Message {
172+
uid: "".into(),
173+
value: Some(generate_random_string_pb(size)),
174+
sub_messages: sub_msg_list.clone(),
175+
};
176+
177+
let mut sub_msg_list2 = vec![sub_msg_1, sub_msg_2];
178+
sub_msg_list2.extend(sub_msg_list.clone());
179+
180+
normal::ObjReq {
181+
msg: Some(msg_for_set_and_field.clone()),
182+
msg_map: vec![msg_map_entry],
183+
sub_msgs: sub_msg_list2,
184+
msg_set: vec![msg_for_set_and_field],
185+
flag_msg: "".into(),
186+
mock_cost: None,
187+
}
188+
}
189+
190+
fn pb_bench_normal(c: &mut criterion::Criterion) {
191+
let mut group = c.benchmark_group("PB Message normal::ObjReq");
192+
193+
for &size_param in SIZES.iter() {
194+
let msg = prepare_obj_req_pb(size_param);
195+
196+
// 计算真实编码长度,用于容量与吞吐统计
197+
let mut ctx0 = EncodeLengthContext::default();
198+
let required_no_ld = msg.encoded_len(&mut ctx0);
199+
let mut ctx1 = EncodeLengthContext::default();
200+
let (_len_only, required_ld_total) = msg.encoded_len_length_delimited(&mut ctx1);
201+
202+
// 预编码为 Bytes 用于解码基准
203+
let mut ctx_enc = EncodeLengthContext::default();
204+
let enc = msg.encode_to_vec(&mut ctx_enc);
205+
let bytes_normal = Bytes::from(enc);
206+
207+
let mut ctx_enc_ld = EncodeLengthContext::default();
208+
let enc_ld = msg.encode_length_delimited_to_vec(&mut ctx_enc_ld);
209+
let bytes_ld = Bytes::from(enc_ld);
210+
211+
group.throughput(Throughput::Bytes(required_no_ld as u64));
212+
213+
group.bench_function(BenchmarkId::new("encode", size_param), |b| {
214+
b.iter(|| {
215+
let mut buf = LinkedBytes::with_capacity(required_no_ld);
216+
msg.encode(&mut buf).unwrap();
217+
black_box(buf);
218+
})
219+
});
220+
221+
group.bench_function(BenchmarkId::new("encode_to_vec", size_param), |b| {
222+
b.iter(|| {
223+
let mut ctx = EncodeLengthContext::default();
224+
let v = msg.encode_to_vec(&mut ctx);
225+
black_box(v);
226+
})
227+
});
228+
229+
group.bench_function(
230+
BenchmarkId::new("encode_length_delimited", size_param),
231+
|b| {
232+
b.iter(|| {
233+
let mut ctx = EncodeLengthContext::default();
234+
let mut buf = LinkedBytes::with_capacity(required_ld_total);
235+
msg.encode_length_delimited(&mut ctx, &mut buf).unwrap();
236+
black_box(buf);
237+
})
238+
},
239+
);
240+
241+
group.bench_function(
242+
BenchmarkId::new("encode_length_delimited_to_vec", size_param),
243+
|b| {
244+
b.iter(|| {
245+
let mut ctx = EncodeLengthContext::default();
246+
let v = msg.encode_length_delimited_to_vec(&mut ctx);
247+
black_box(v);
248+
})
249+
},
250+
);
251+
252+
group.bench_function(BenchmarkId::new("decode", size_param), |b| {
253+
b.iter(|| {
254+
let decoded = <normal::ObjReq as Message>::decode(bytes_normal.clone()).unwrap();
255+
black_box(decoded);
256+
})
257+
});
258+
259+
group.bench_function(
260+
BenchmarkId::new("decode_length_delimited", size_param),
261+
|b| {
262+
b.iter(|| {
263+
let decoded =
264+
<normal::ObjReq as Message>::decode_length_delimited(bytes_ld.clone())
265+
.unwrap();
266+
black_box(decoded);
267+
})
268+
},
269+
);
270+
}
271+
272+
group.finish();
273+
}
274+
275+
criterion_group!(benches, pb_bench, pb_bench_normal);
276+
criterion_main!(benches);

pilota/src/pb/message.rs

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ pub trait Message: Debug + Send + Sync {
6161
Ok(())
6262
}
6363

64+
/// Returns the encoded length of the message with a length delimiter.
65+
///
66+
/// (length, total)
67+
/// - length: The encoded length of the message itself
68+
/// - total: The encoded length of the message with the length delimiter
69+
fn encoded_len_length_delimited(&self, ctx: &mut EncodeLengthContext) -> (usize, usize) {
70+
let len = self.encoded_len(ctx);
71+
let required = len + encoded_len_varint(len as u64);
72+
(len, required)
73+
}
74+
6475
/// Encodes the message with a length-delimiter to a buffer.
6576
///
6677
/// An error will be returned if the buffer does not have sufficient
@@ -73,8 +84,7 @@ pub trait Message: Debug + Send + Sync {
7384
where
7485
Self: Sized,
7586
{
76-
let len = self.encoded_len(ctx);
77-
let required = len + encoded_len_varint(len as u64);
87+
let (len, required) = self.encoded_len_length_delimited(ctx);
7888
let remaining = buf.remaining_mut();
7989
if required > remaining {
8090
return Err(EncodeError::new(required, remaining));
@@ -84,6 +94,32 @@ pub trait Message: Debug + Send + Sync {
8494
Ok(())
8595
}
8696

97+
/// Encodes the message to a newly allocated buffer.
98+
fn encode_to_vec(&self, ctx: &mut EncodeLengthContext) -> Vec<u8>
99+
where
100+
Self: Sized,
101+
{
102+
let len = self.encoded_len(ctx);
103+
let required = len - ctx.zero_copy_len;
104+
105+
let mut buf = LinkedBytes::with_capacity(required);
106+
self.encode_raw(&mut buf);
107+
108+
buf.concat().to_vec()
109+
}
110+
111+
/// Encodes the message with a length-delimiter to a newly allocated buffer.
112+
fn encode_length_delimited_to_vec(&self, ctx: &mut EncodeLengthContext) -> Vec<u8>
113+
where
114+
Self: Sized,
115+
{
116+
let (len, required) = self.encoded_len_length_delimited(ctx);
117+
let mut buf = LinkedBytes::with_capacity(required);
118+
encode_varint(len as u64, &mut buf);
119+
self.encode_raw(&mut buf);
120+
buf.concat().to_vec()
121+
}
122+
87123
/// Decodes an instance of the message from a buffer.
88124
///
89125
/// The entire buffer will be consumed.
@@ -354,6 +390,23 @@ mod tests {
354390
assert_eq!(buf.len(), 2);
355391
}
356392

393+
/// `encode_to_vec` on a message holding 42 must yield exactly the two bytes
/// `0x08, 0x2A`.
#[test]
fn test_encode_to_vec_and_zero_copy_len() {
    let mut length_ctx = EncodeLengthContext::default();
    let encoded = TestMessage::new(42).encode_to_vec(&mut length_ctx);

    let expected: Vec<u8> = vec![0x08, 0x2A];
    assert_eq!(encoded, expected);
}
400+
401+
/// `encode_length_delimited_to_vec` must put the payload length in front of
/// the payload bytes.
#[test]
fn test_encode_length_delimited_to_vec() {
    let mut length_ctx = EncodeLengthContext::default();
    let encoded = TestMessage::new(300).encode_length_delimited_to_vec(&mut length_ctx);

    // payload = [0x08, 0xAC, 0x02]; the length varint for 3 bytes is [0x03]
    let expected: Vec<u8> = vec![0x03, 0x08, 0xAC, 0x02];
    assert_eq!(encoded, expected);
}
409+
357410
#[test]
358411
fn test_message_encode_error() {
359412
let msg = TestMessage::new(42);

0 commit comments

Comments
 (0)