Skip to content

Commit 6b9a653

Browse files
committed
refactor(density): align serde boundaries and family id checks
1 parent 45730e2 commit 6b9a653

File tree

4 files changed

+241
-208
lines changed

4 files changed

+241
-208
lines changed

datasketches/src/codec/family.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ impl Family {
7373
max_pre_longs: 2,
7474
};
7575

76+
/// Density sketch for streaming density estimation.
///
/// NOTE(review): `density/serialization.rs` sizes its preamble as
/// `preamble_ints * 4` bytes and writes 3 (empty) or 6 (non-empty) as the
/// first byte, so these bounds appear to count 4-byte ints despite the
/// `pre_longs` field names — confirm.
77+
pub const DENSITY: Family = Family {
78+
id: 19,
79+
name: "DENSITY",
80+
min_pre_longs: 3,
81+
max_pre_longs: 6,
82+
};
83+
7684
/// T-Digest for estimating quantiles and ranks.
7785
pub const TDIGEST: Family = Family {
7886
id: 20,

datasketches/src/density/serialization.rs

Lines changed: 171 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,174 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
pub(super) const PREAMBLE_INTS_SHORT: u8 = 3;
19-
pub(super) const PREAMBLE_INTS_LONG: u8 = 6;
20-
pub(super) const SERIAL_VERSION: u8 = 1;
21-
pub(super) const DENSITY_FAMILY_ID: u8 = 19;
22-
pub(super) const FLAGS_IS_EMPTY: u8 = 1 << 2;
18+
use crate::codec::SketchBytes;
19+
use crate::codec::SketchSlice;
20+
use crate::codec::assert::ensure_preamble_longs_in;
21+
use crate::codec::assert::ensure_serial_version_is;
22+
use crate::codec::family::Family;
23+
use crate::error::Error;
24+
use crate::error::ErrorKind;
25+
26+
/// Preamble size, in 4-byte ints, written and expected for an empty sketch.
const PREAMBLE_INTS_SHORT: u8 = 3;
27+
/// Preamble size, in 4-byte ints, written and expected for a non-empty sketch.
const PREAMBLE_INTS_LONG: u8 = 6;
28+
/// Serial version written as the second byte and required when decoding.
const SERIAL_VERSION: u8 = 1;
29+
/// Bit of the flags byte that marks an empty sketch.
const FLAGS_IS_EMPTY: u8 = 1 << 2;
30+
31+
/// Values of a single point; the decoder reads `dim` values per point.
type Point<T> = Vec<T>;
32+
/// Points stored in one level of the sketch.
type Level<T> = Vec<Point<T>>;
33+
/// All levels of a sketch, in the order they are written and read.
type Levels<T> = Vec<Level<T>>;
34+
/// Callback that appends one value to the output buffer.
type SerializeValue<T> = fn(&mut SketchBytes, T);
35+
/// Callback that reads one value from the input cursor.
type DeserializeValue<T> = fn(&mut SketchSlice<'_>) -> std::io::Result<T>;
36+
37+
/// Fields recovered from a serialized density sketch by `deserialize_inner`.
pub(super) struct DecodedSketch<T> {
38+
/// Sketch parameter `k`; the decoder rejects values below 2.
pub(super) k: u16,
39+
/// Number of values stored per point.
pub(super) dim: u32,
40+
/// Total number of points across all levels (0 for an empty image).
pub(super) num_retained: u32,
41+
/// The `n` field of the image (0 for an empty image).
/// NOTE(review): presumably the count of items the sketch has seen — confirm.
pub(super) n: u64,
42+
/// Decoded levels; an empty image yields a single empty level.
pub(super) levels: Levels<T>,
43+
}
44+
45+
/// Read-only view of the sketch state that `serialize_inner` needs.
pub(super) trait SketchSerializationView<T> {
46+
/// Whether the sketch is empty; selects the short preamble and the empty flag.
fn is_empty(&self) -> bool;
47+
/// Sketch parameter `k`, written as a little-endian u16.
fn k(&self) -> u16;
48+
/// Number of values per point, written as a little-endian u32.
fn dim(&self) -> u32;
49+
/// Number of retained points, written only for non-empty sketches.
fn num_retained(&self) -> u32;
50+
/// The `n` field, written only for non-empty sketches.
fn n(&self) -> u64;
51+
/// Levels whose points are written only for non-empty sketches.
fn levels(&self) -> &[Level<T>];
52+
}
53+
54+
pub(super) fn serialize_f32<S: SketchSerializationView<f32>>(sketch: &S) -> Vec<u8> {
55+
serialize_inner(sketch, 4, |bytes, value| bytes.write_f32_le(value))
56+
}
57+
58+
pub(super) fn serialize_f64<S: SketchSerializationView<f64>>(sketch: &S) -> Vec<u8> {
59+
serialize_inner(sketch, 8, |bytes, value| bytes.write_f64_le(value))
60+
}
61+
62+
pub(super) fn deserialize_f32(bytes: &[u8]) -> Result<DecodedSketch<f32>, Error> {
63+
deserialize_inner(bytes, |cursor| cursor.read_f32_le())
64+
}
65+
66+
pub(super) fn deserialize_f64(bytes: &[u8]) -> Result<DecodedSketch<f64>, Error> {
67+
deserialize_inner(bytes, |cursor| cursor.read_f64_le())
68+
}
69+
70+
fn serialize_inner<T: Copy, S: SketchSerializationView<T>>(
71+
sketch: &S,
72+
value_size: usize,
73+
write_value: SerializeValue<T>,
74+
) -> Vec<u8> {
75+
let preamble_ints = if sketch.is_empty() {
76+
PREAMBLE_INTS_SHORT
77+
} else {
78+
PREAMBLE_INTS_LONG
79+
};
80+
let mut size_bytes = preamble_ints as usize * 4;
81+
if !sketch.is_empty() {
82+
for level in sketch.levels() {
83+
size_bytes += 4 + (level.len() * sketch.dim() as usize * value_size);
84+
}
85+
}
86+
87+
let mut bytes = SketchBytes::with_capacity(size_bytes);
88+
bytes.write_u8(preamble_ints);
89+
bytes.write_u8(SERIAL_VERSION);
90+
bytes.write_u8(Family::DENSITY.id);
91+
let flags = if sketch.is_empty() { FLAGS_IS_EMPTY } else { 0 };
92+
bytes.write_u8(flags);
93+
bytes.write_u16_le(sketch.k());
94+
bytes.write_u16_le(0);
95+
bytes.write_u32_le(sketch.dim());
96+
97+
if sketch.is_empty() {
98+
return bytes.into_bytes();
99+
}
100+
101+
bytes.write_u32_le(sketch.num_retained());
102+
bytes.write_u64_le(sketch.n());
103+
for level in sketch.levels() {
104+
bytes.write_u32_le(level.len() as u32);
105+
for point in level {
106+
for value in point {
107+
write_value(&mut bytes, *value);
108+
}
109+
}
110+
}
111+
bytes.into_bytes()
112+
}
113+
114+
fn deserialize_inner<T>(
115+
bytes: &[u8],
116+
read_value: DeserializeValue<T>,
117+
) -> Result<DecodedSketch<T>, Error> {
118+
fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error {
119+
move |_| Error::insufficient_data(tag)
120+
}
121+
122+
let mut cursor = SketchSlice::new(bytes);
123+
let preamble_ints = cursor.read_u8().map_err(make_error("preamble_ints"))?;
124+
let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?;
125+
let family_id = cursor.read_u8().map_err(make_error("family_id"))?;
126+
let flags = cursor.read_u8().map_err(make_error("flags"))?;
127+
let k = cursor.read_u16_le().map_err(make_error("k"))?;
128+
cursor.read_u16_le().map_err(make_error("unused"))?;
129+
let dim = cursor.read_u32_le().map_err(make_error("dim"))?;
130+
131+
Family::DENSITY.validate_id(family_id)?;
132+
ensure_serial_version_is(SERIAL_VERSION, serial_version)?;
133+
if k < 2 {
134+
return Err(Error::new(
135+
ErrorKind::InvalidArgument,
136+
format!("k must be > 1. Found: {k}"),
137+
));
138+
}
139+
140+
let is_empty = (flags & FLAGS_IS_EMPTY) != 0;
141+
let expected_preamble = if is_empty {
142+
PREAMBLE_INTS_SHORT
143+
} else {
144+
PREAMBLE_INTS_LONG
145+
};
146+
ensure_preamble_longs_in(&[expected_preamble], preamble_ints)?;
147+
if is_empty {
148+
return Ok(DecodedSketch {
149+
k,
150+
dim,
151+
num_retained: 0,
152+
n: 0,
153+
levels: vec![Vec::new()],
154+
});
155+
}
156+
157+
let num_retained = cursor.read_u32_le().map_err(make_error("num_retained"))?;
158+
let n = cursor.read_u64_le().map_err(make_error("n"))?;
159+
160+
let mut levels = Vec::new();
161+
let mut remaining = num_retained as i64;
162+
while remaining > 0 {
163+
let level_size = cursor.read_u32_le().map_err(make_error("level_size"))?;
164+
let mut level = Vec::with_capacity(level_size as usize);
165+
for _ in 0..level_size {
166+
let mut point = Vec::with_capacity(dim as usize);
167+
for _ in 0..dim {
168+
point.push(read_value(&mut cursor).map_err(make_error("point"))?);
169+
}
170+
level.push(point);
171+
}
172+
remaining -= level_size as i64;
173+
levels.push(level);
174+
}
175+
if remaining != 0 {
176+
return Err(Error::deserial(
177+
"invalid number of retained points while decoding density sketch",
178+
));
179+
}
180+
181+
Ok(DecodedSketch {
182+
k,
183+
dim,
184+
num_retained,
185+
n,
186+
levels,
187+
})
188+
}

0 commit comments

Comments
 (0)