|
15 | 15 | // specific language governing permissions and limitations |
16 | 16 | // under the License. |
17 | 17 |
|
18 | | -pub(super) const PREAMBLE_INTS_SHORT: u8 = 3; |
19 | | -pub(super) const PREAMBLE_INTS_LONG: u8 = 6; |
20 | | -pub(super) const SERIAL_VERSION: u8 = 1; |
21 | | -pub(super) const DENSITY_FAMILY_ID: u8 = 19; |
22 | | -pub(super) const FLAGS_IS_EMPTY: u8 = 1 << 2; |
| 18 | +use crate::codec::SketchBytes; |
| 19 | +use crate::codec::SketchSlice; |
| 20 | +use crate::codec::assert::ensure_preamble_longs_in; |
| 21 | +use crate::codec::assert::ensure_serial_version_is; |
| 22 | +use crate::codec::family::Family; |
| 23 | +use crate::error::Error; |
| 24 | +use crate::error::ErrorKind; |
| 25 | + |
/// Preamble length, counted in 4-byte ints, used for an empty sketch.
const PREAMBLE_INTS_SHORT: u8 = 3;
/// Preamble length, counted in 4-byte ints, used for a non-empty sketch.
const PREAMBLE_INTS_LONG: u8 = 6;
/// Serial version byte written by the encoder and required by the decoder.
const SERIAL_VERSION: u8 = 1;
/// Bit of the flags byte that marks an empty sketch.
const FLAGS_IS_EMPTY: u8 = 1 << 2;
| 30 | + |
| 31 | +type Point<T> = Vec<T>; |
| 32 | +type Level<T> = Vec<Point<T>>; |
| 33 | +type Levels<T> = Vec<Level<T>>; |
| 34 | +type SerializeValue<T> = fn(&mut SketchBytes, T); |
| 35 | +type DeserializeValue<T> = fn(&mut SketchSlice<'_>) -> std::io::Result<T>; |
| 36 | + |
| 37 | +pub(super) struct DecodedSketch<T> { |
| 38 | + pub(super) k: u16, |
| 39 | + pub(super) dim: u32, |
| 40 | + pub(super) num_retained: u32, |
| 41 | + pub(super) n: u64, |
| 42 | + pub(super) levels: Levels<T>, |
| 43 | +} |
| 44 | + |
| 45 | +pub(super) trait SketchSerializationView<T> { |
| 46 | + fn is_empty(&self) -> bool; |
| 47 | + fn k(&self) -> u16; |
| 48 | + fn dim(&self) -> u32; |
| 49 | + fn num_retained(&self) -> u32; |
| 50 | + fn n(&self) -> u64; |
| 51 | + fn levels(&self) -> &[Level<T>]; |
| 52 | +} |
| 53 | + |
| 54 | +pub(super) fn serialize_f32<S: SketchSerializationView<f32>>(sketch: &S) -> Vec<u8> { |
| 55 | + serialize_inner(sketch, 4, |bytes, value| bytes.write_f32_le(value)) |
| 56 | +} |
| 57 | + |
| 58 | +pub(super) fn serialize_f64<S: SketchSerializationView<f64>>(sketch: &S) -> Vec<u8> { |
| 59 | + serialize_inner(sketch, 8, |bytes, value| bytes.write_f64_le(value)) |
| 60 | +} |
| 61 | + |
| 62 | +pub(super) fn deserialize_f32(bytes: &[u8]) -> Result<DecodedSketch<f32>, Error> { |
| 63 | + deserialize_inner(bytes, |cursor| cursor.read_f32_le()) |
| 64 | +} |
| 65 | + |
| 66 | +pub(super) fn deserialize_f64(bytes: &[u8]) -> Result<DecodedSketch<f64>, Error> { |
| 67 | + deserialize_inner(bytes, |cursor| cursor.read_f64_le()) |
| 68 | +} |
| 69 | + |
| 70 | +fn serialize_inner<T: Copy, S: SketchSerializationView<T>>( |
| 71 | + sketch: &S, |
| 72 | + value_size: usize, |
| 73 | + write_value: SerializeValue<T>, |
| 74 | +) -> Vec<u8> { |
| 75 | + let preamble_ints = if sketch.is_empty() { |
| 76 | + PREAMBLE_INTS_SHORT |
| 77 | + } else { |
| 78 | + PREAMBLE_INTS_LONG |
| 79 | + }; |
| 80 | + let mut size_bytes = preamble_ints as usize * 4; |
| 81 | + if !sketch.is_empty() { |
| 82 | + for level in sketch.levels() { |
| 83 | + size_bytes += 4 + (level.len() * sketch.dim() as usize * value_size); |
| 84 | + } |
| 85 | + } |
| 86 | + |
| 87 | + let mut bytes = SketchBytes::with_capacity(size_bytes); |
| 88 | + bytes.write_u8(preamble_ints); |
| 89 | + bytes.write_u8(SERIAL_VERSION); |
| 90 | + bytes.write_u8(Family::DENSITY.id); |
| 91 | + let flags = if sketch.is_empty() { FLAGS_IS_EMPTY } else { 0 }; |
| 92 | + bytes.write_u8(flags); |
| 93 | + bytes.write_u16_le(sketch.k()); |
| 94 | + bytes.write_u16_le(0); |
| 95 | + bytes.write_u32_le(sketch.dim()); |
| 96 | + |
| 97 | + if sketch.is_empty() { |
| 98 | + return bytes.into_bytes(); |
| 99 | + } |
| 100 | + |
| 101 | + bytes.write_u32_le(sketch.num_retained()); |
| 102 | + bytes.write_u64_le(sketch.n()); |
| 103 | + for level in sketch.levels() { |
| 104 | + bytes.write_u32_le(level.len() as u32); |
| 105 | + for point in level { |
| 106 | + for value in point { |
| 107 | + write_value(&mut bytes, *value); |
| 108 | + } |
| 109 | + } |
| 110 | + } |
| 111 | + bytes.into_bytes() |
| 112 | +} |
| 113 | + |
| 114 | +fn deserialize_inner<T>( |
| 115 | + bytes: &[u8], |
| 116 | + read_value: DeserializeValue<T>, |
| 117 | +) -> Result<DecodedSketch<T>, Error> { |
| 118 | + fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { |
| 119 | + move |_| Error::insufficient_data(tag) |
| 120 | + } |
| 121 | + |
| 122 | + let mut cursor = SketchSlice::new(bytes); |
| 123 | + let preamble_ints = cursor.read_u8().map_err(make_error("preamble_ints"))?; |
| 124 | + let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?; |
| 125 | + let family_id = cursor.read_u8().map_err(make_error("family_id"))?; |
| 126 | + let flags = cursor.read_u8().map_err(make_error("flags"))?; |
| 127 | + let k = cursor.read_u16_le().map_err(make_error("k"))?; |
| 128 | + cursor.read_u16_le().map_err(make_error("unused"))?; |
| 129 | + let dim = cursor.read_u32_le().map_err(make_error("dim"))?; |
| 130 | + |
| 131 | + Family::DENSITY.validate_id(family_id)?; |
| 132 | + ensure_serial_version_is(SERIAL_VERSION, serial_version)?; |
| 133 | + if k < 2 { |
| 134 | + return Err(Error::new( |
| 135 | + ErrorKind::InvalidArgument, |
| 136 | + format!("k must be > 1. Found: {k}"), |
| 137 | + )); |
| 138 | + } |
| 139 | + |
| 140 | + let is_empty = (flags & FLAGS_IS_EMPTY) != 0; |
| 141 | + let expected_preamble = if is_empty { |
| 142 | + PREAMBLE_INTS_SHORT |
| 143 | + } else { |
| 144 | + PREAMBLE_INTS_LONG |
| 145 | + }; |
| 146 | + ensure_preamble_longs_in(&[expected_preamble], preamble_ints)?; |
| 147 | + if is_empty { |
| 148 | + return Ok(DecodedSketch { |
| 149 | + k, |
| 150 | + dim, |
| 151 | + num_retained: 0, |
| 152 | + n: 0, |
| 153 | + levels: vec![Vec::new()], |
| 154 | + }); |
| 155 | + } |
| 156 | + |
| 157 | + let num_retained = cursor.read_u32_le().map_err(make_error("num_retained"))?; |
| 158 | + let n = cursor.read_u64_le().map_err(make_error("n"))?; |
| 159 | + |
| 160 | + let mut levels = Vec::new(); |
| 161 | + let mut remaining = num_retained as i64; |
| 162 | + while remaining > 0 { |
| 163 | + let level_size = cursor.read_u32_le().map_err(make_error("level_size"))?; |
| 164 | + let mut level = Vec::with_capacity(level_size as usize); |
| 165 | + for _ in 0..level_size { |
| 166 | + let mut point = Vec::with_capacity(dim as usize); |
| 167 | + for _ in 0..dim { |
| 168 | + point.push(read_value(&mut cursor).map_err(make_error("point"))?); |
| 169 | + } |
| 170 | + level.push(point); |
| 171 | + } |
| 172 | + remaining -= level_size as i64; |
| 173 | + levels.push(level); |
| 174 | + } |
| 175 | + if remaining != 0 { |
| 176 | + return Err(Error::deserial( |
| 177 | + "invalid number of retained points while decoding density sketch", |
| 178 | + )); |
| 179 | + } |
| 180 | + |
| 181 | + Ok(DecodedSketch { |
| 182 | + k, |
| 183 | + dim, |
| 184 | + num_retained, |
| 185 | + n, |
| 186 | + levels, |
| 187 | + }) |
| 188 | +} |
0 commit comments