Skip to content

Commit 342ab2b

Browse files
authored
RUST-889 Deserialize ObjectId directly from bytes in raw deserializer (#289)
1 parent aeef692 commit 342ab2b

File tree

2 files changed

+67
-23
lines changed

2 files changed

+67
-23
lines changed

src/de/raw.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,22 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> {
271271
}
272272
}
273273

274+
fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
275+
where
276+
V: serde::de::Visitor<'de>,
277+
{
278+
match self.current_type {
279+
ElementType::ObjectId => visitor.visit_borrowed_bytes(self.bytes.read_slice(12)?),
280+
_ => self.deserialize_any(visitor),
281+
}
282+
}
283+
274284
fn is_human_readable(&self) -> bool {
275285
false
276286
}
277287

278288
forward_to_deserialize_any! {
279-
bool char str bytes byte_buf unit unit_struct string
289+
bool char str byte_buf unit unit_struct string
280290
identifier newtype_struct seq tuple tuple_struct struct
281291
map ignored_any i8 i16 i32 i64 u8 u16 u32 u64 f32 f64
282292
}

src/de/serde.rs

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,32 +32,66 @@ use super::raw::Decimal128Access;
3232

3333
pub(crate) struct BsonVisitor;
3434

35-
impl<'de> Deserialize<'de> for ObjectId {
36-
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
35+
struct ObjectIdVisitor;
36+
37+
impl<'de> Visitor<'de> for ObjectIdVisitor {
38+
type Value = ObjectId;
39+
40+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
41+
formatter.write_str("expecting an ObjectId")
42+
}
43+
44+
#[inline]
45+
fn visit_str<E>(self, value: &str) -> std::result::Result<Self::Value, E>
3746
where
38-
D: de::Deserializer<'de>,
47+
E: serde::de::Error,
3948
{
40-
#[derive(serde::Deserialize)]
41-
#[serde(untagged)]
42-
enum OidHelper {
43-
HexString(String),
44-
Bson(Bson),
49+
ObjectId::parse_str(value).map_err(|_| {
50+
E::invalid_value(
51+
Unexpected::Str(value),
52+
&"24-character, big-endian hex string",
53+
)
54+
})
55+
}
56+
57+
#[inline]
58+
fn visit_bytes<E>(self, v: &[u8]) -> std::result::Result<Self::Value, E>
59+
where
60+
E: serde::de::Error,
61+
{
62+
let bytes: [u8; 12] = v
63+
.try_into()
64+
.map_err(|_| E::invalid_length(v.len(), &"12 bytes"))?;
65+
Ok(ObjectId::from_bytes(bytes))
66+
}
67+
68+
#[inline]
69+
fn visit_map<V>(self, mut visitor: V) -> Result<Self::Value, V::Error>
70+
where
71+
V: MapAccess<'de>,
72+
{
73+
match BsonVisitor.visit_map(&mut visitor)? {
74+
Bson::ObjectId(oid) => Ok(oid),
75+
bson => {
76+
let err = format!(
77+
"expected map containing extended-JSON formatted ObjectId, instead found {}",
78+
bson
79+
);
80+
Err(de::Error::custom(err))
81+
}
4582
}
83+
}
84+
}
4685

47-
match OidHelper::deserialize(deserializer)
48-
.map_err(|_| de::Error::custom("expected ObjectId extended document or hex string"))?
49-
{
50-
OidHelper::HexString(s) => ObjectId::parse_str(&s).map_err(de::Error::custom),
51-
OidHelper::Bson(bson) => match bson {
52-
Bson::ObjectId(oid) => Ok(oid),
53-
bson => {
54-
let err = format!(
55-
"expected objectId extended document or hex string, found {}",
56-
bson
57-
);
58-
Err(de::Error::invalid_type(Unexpected::Map, &&err[..]))
59-
}
60-
},
86+
impl<'de> Deserialize<'de> for ObjectId {
87+
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
88+
where
89+
D: serde::Deserializer<'de>,
90+
{
91+
if !deserializer.is_human_readable() {
92+
deserializer.deserialize_bytes(ObjectIdVisitor)
93+
} else {
94+
deserializer.deserialize_any(ObjectIdVisitor)
6195
}
6296
}
6397
}

0 commit comments

Comments
 (0)