Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 181 additions & 76 deletions src/schema/de.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
use std::{collections::HashMap, fmt};

use serde::{
de::{MapAccess, SeqAccess, Visitor},
de::{DeserializeSeed, MapAccess, SeqAccess, Visitor},
Deserialize, Deserializer,
};
use serde_json::Value;

use super::*;

/// Should be used instead of serde_json::from_value to preserve
/// information about already parsed types.
fn deserialize_schema_from_value(
value: Value,
visited: &mut HashMap<String, Schema>
) -> Result<Schema, serde_json::Error> {
SchemaDeserializer { visited_types: visited }.deserialize(value)
}

fn to_primitive(v: &str) -> Option<Schema> {
use Schema::*;
Some(match v {
Expand Down Expand Up @@ -86,6 +95,10 @@ fn remove_vec_string<E: serde::de::Error>(
}
}

fn remove_json_value(data: &mut HashMap<String, Value>, key: &str,) -> Option<Value> {
data.remove(key)
}

fn to_enum<E: serde::de::Error>(data: &mut HashMap<String, Value>) -> Result<Schema, E> {
Ok(Schema::Enum(Enum {
name: remove_string(data, "name")?
Expand All @@ -98,59 +111,10 @@ fn to_enum<E: serde::de::Error>(data: &mut HashMap<String, Value>) -> Result<Sch
}))
}

fn to_map<E: serde::de::Error>(data: &mut HashMap<String, Value>) -> Result<Schema, E> {
let item = data
.remove("values")
.ok_or_else(|| serde::de::Error::custom("values is required in a map"))?;
let schema: Schema = serde_json::from_value(item).map_err(serde::de::Error::custom)?;
Ok(Schema::Map(Box::new(schema)))
}

fn to_schema<E: serde::de::Error>(
data: &mut HashMap<String, Value>,
key: &str,
) -> Result<Option<Schema>, E> {
let schema = data.remove(key);
schema
.map(|schema| serde_json::from_value(schema).map_err(serde::de::Error::custom))
.transpose()
}

fn to_array<E: serde::de::Error>(data: &mut HashMap<String, Value>) -> Result<Schema, E> {
let schema =
to_schema(data, "items")?.ok_or_else(|| E::custom("items is required in an array"))?;
Ok(Schema::Array(Box::new(schema)))
}

fn to_field<E: serde::de::Error>(data: Value) -> Result<Field, E> {
serde_json::from_value(data).map_err(E::custom)
}

fn to_vec_fields<E: serde::de::Error>(
data: &mut HashMap<String, Value>,
key: &str,
) -> Result<Vec<Field>, E> {
match data.remove(key) {
Some(s) => {
if let Value::Array(x) = s {
x.into_iter().map(to_field).collect()
} else {
Err(E::custom(format!("{} must be a string", key)))
}
}
None => Ok(vec![]),
}
}

fn to_record<E: serde::de::Error>(data: &mut HashMap<String, Value>) -> Result<Schema, E> {
Ok(Schema::Record(Record {
name: remove_string(data, "name")?
.ok_or_else(|| serde::de::Error::custom("name is required in enum"))?,
namespace: remove_string(data, "namespace")?,
aliases: remove_vec_string(data, "aliases")?,
doc: remove_string(data, "doc")?,
fields: to_vec_fields(data, "fields")?,
}))
fn to_field<E: serde::de::Error>(data: Value, visited: &mut HashMap<String, Schema>) -> Result<Field, E> {
FieldDeserializer {
visited_types: visited,
}.deserialize(data).map_err(E::custom)
}

fn to_fixed<E: serde::de::Error>(data: &mut HashMap<String, Value>) -> Result<Schema, E> {
Expand Down Expand Up @@ -199,9 +163,68 @@ fn to_order<E: serde::de::Error>(
.transpose()
}

struct SchemaVisitor {}
struct SchemaVisitor<'a> {
visited_types: &'a mut HashMap<String, Schema>,
}

impl<'a>SchemaVisitor<'a> {
fn to_map<E: serde::de::Error>(&mut self, data: &mut HashMap<String, Value>) -> Result<Schema, E> {
let item = data
.remove("values")
.ok_or_else(|| serde::de::Error::custom("values is required in a map"))?;
let schema: Schema = deserialize_schema_from_value(item, self.visited_types).map_err(serde::de::Error::custom)?;
Ok(Schema::Map(Box::new(schema)))
}

fn to_schema<E: serde::de::Error>(
&mut self,
data: &mut HashMap<String, Value>,
key: &str,
) -> Result<Option<Schema>, E> {
let schema = data.remove(key);
schema
.map(|schema| deserialize_schema_from_value(schema, self.visited_types).map_err(serde::de::Error::custom))
.transpose()
}

impl<'de> Visitor<'de> for SchemaVisitor {
fn to_array<E: serde::de::Error>(&mut self, data: &mut HashMap<String, Value>) -> Result<Schema, E> {
let schema =
self.to_schema(data, "items")?.ok_or_else(|| E::custom("items is required in an array"))?;
Ok(Schema::Array(Box::new(schema)))
}

fn to_vec_fields<E: serde::de::Error>(
&mut self,
data: &mut HashMap<String, Value>,
key: &str,
) -> Result<Vec<Field>, E> {
match data.remove(key) {
Some(s) => {
if let Value::Array(x) = s {
x.into_iter()
.map(|a| to_field(a.clone(), self.visited_types))
.collect()
} else {
Err(E::custom(format!("{} must be a string", key)))
}
}
None => Ok(vec![]),
}
}

fn to_record<E: serde::de::Error>(&mut self, data: &mut HashMap<String, Value>) -> Result<Schema, E> {
Ok(Schema::Record(Record {
name: remove_string(data, "name")?
.ok_or_else(|| serde::de::Error::custom("name is required in enum"))?,
namespace: remove_string(data, "namespace")?,
aliases: remove_vec_string(data, "aliases")?,
doc: remove_string(data, "doc")?,
fields: self.to_vec_fields(data, "fields")?,
}))
}
}

impl<'de, 'a> Visitor<'de> for SchemaVisitor<'a> {
type Value = Schema;

// Format a message stating what data this Visitor expects to receive.
Expand All @@ -213,7 +236,7 @@ impl<'de> Visitor<'de> for SchemaVisitor {
where
D: Deserializer<'de>,
{
deserializer.deserialize_any(SchemaVisitor {})
deserializer.deserialize_any(self)
}

fn visit_none<E>(self) -> Result<Self::Value, E>
Expand All @@ -227,8 +250,23 @@ impl<'de> Visitor<'de> for SchemaVisitor {
where
E: serde::de::Error,
{
to_primitive(v)
.ok_or_else(|| serde::de::Error::custom("string must be a valid primitive Schema"))
let r#type = to_primitive(v);
if let Some(t) = r#type {
Ok(t)
} else {
// Previously defined types are referred using just a string
// containing the name of that previously defined type.
// That's why when the type cannot be parsed as primitive type,
// need to also check names of previously parsed types.
if let Some(t) = self.visited_types.get(&v.to_string()) {
Ok(t.clone())
} else {
for vt in self.visited_types.keys() {
println!("Visited type: {vt}");
}
Err(serde::de::Error::custom(format!("string '{v}' must be a valid primitive Schema")))
}
}
}

fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
Expand All @@ -237,13 +275,13 @@ impl<'de> Visitor<'de> for SchemaVisitor {
{
let mut vec = Vec::with_capacity(seq.size_hint().unwrap_or(0));
while let Some(item) = seq.next_element::<Value>()? {
let schema: Schema = serde_json::from_value(item).map_err(serde::de::Error::custom)?;
let schema: Schema = deserialize_schema_from_value(item, self.visited_types).map_err(serde::de::Error::custom)?;
vec.push(schema)
}
Ok(Schema::Union(vec))
}

fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
fn visit_map<M>(mut self, mut access: M) -> Result<Self::Value, M::Error>
where
M: MapAccess<'de>,
{
Expand Down Expand Up @@ -304,11 +342,32 @@ impl<'de> Visitor<'de> for SchemaVisitor {
})
} else {
match type_.as_ref() {
"enum" => to_enum(&mut map),
"map" => to_map(&mut map),
"array" => to_array(&mut map),
"record" => to_record(&mut map),
"fixed" => to_fixed(&mut map),
"enum" => {
let custom_type = to_enum(&mut map);
if let Ok(t) = custom_type.as_ref() {
let type_name = name_of_complex_type(t);
self.visited_types.insert(type_name, t.clone());
}
custom_type
},
"map" => self.to_map(&mut map),
"array" => self.to_array(&mut map),
"record" => {
let custom_type = self.to_record(&mut map);
if let Ok(t) = custom_type.as_ref() {
let type_name = name_of_complex_type(t);
self.visited_types.insert(type_name, t.clone());
}
custom_type
},
"fixed" => {
let custom_type = to_fixed(&mut map);
if let Ok(t) = custom_type.as_ref() {
let type_name = name_of_complex_type(t);
self.visited_types.insert(type_name, t.clone());
}
custom_type
},
other => todo!("{}", other),
}
}
Expand All @@ -320,13 +379,37 @@ impl<'de> Deserialize<'de> for Schema {
where
D: Deserializer<'de>,
{
deserializer.deserialize_option(SchemaVisitor {})
let mut visited = HashMap::new();
deserializer.deserialize_option(SchemaVisitor { visited_types: &mut visited})
}
}

struct FieldVisitor {}
/// AVRO schema allows so called ref types:
/// after a complex type (record, enum, fixed) is defined in a schema,
/// subsequent fields that have the same type can refer this
/// type by just specifying the type name.
///
/// To keep information about previously parsed types,
/// we have to use a stateful deserializer based on DeserializeSeed.
struct SchemaDeserializer<'a> {
visited_types: &'a mut HashMap<String, Schema>,
}

impl<'de, 'a> DeserializeSeed<'de> for SchemaDeserializer<'a> {
type Value = Schema;

fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de> {
deserializer.deserialize_option(SchemaVisitor { visited_types: self.visited_types })
}
}

impl<'de> Visitor<'de> for FieldVisitor {
struct FieldVisitor<'a> {
visited_types: &'a mut HashMap<String, Schema>,
}

impl<'de, 'a> Visitor<'de> for FieldVisitor<'a> {
type Value = Field;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
Expand All @@ -345,24 +428,46 @@ impl<'de> Visitor<'de> for FieldVisitor {
map.insert(key, value);
}

let parsed_schema =
SchemaVisitor {visited_types: self.visited_types}
.to_schema(&mut map, "type");

Ok(Field {
name: remove_string(&mut map, "name")?
.ok_or_else(|| serde::de::Error::custom("name is required in enum"))?,
doc: remove_string(&mut map, "doc")?,
schema: to_schema(&mut map, "type")?
schema: parsed_schema?
.ok_or_else(|| serde::de::Error::custom("type is required in Field"))?,
default: to_schema(&mut map, "default")?,
default: remove_json_value(&mut map, "default"),
order: to_order(&mut map, "order")?,
aliases: remove_vec_string(&mut map, "aliases")?,
})
}
}

impl<'de> Deserialize<'de> for Field {
fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
struct FieldDeserializer<'a> {
visited_types: &'a mut HashMap<String, Schema>,
}

impl<'de, 'a> DeserializeSeed<'de> for FieldDeserializer<'a> {
type Value = Field;

fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_map(FieldVisitor {})
D: Deserializer<'de> {
deserializer.deserialize_map(FieldVisitor { visited_types: self.visited_types})
}
}

fn name_of_complex_type(schema: &Schema) -> String {
match schema {
Schema::Record(record) =>
format!("{}{}", record.namespace.as_ref().map(|s|format!(".{s}")).unwrap_or_default(), record.name),
Schema::Enum(r#enum) =>
format!("{}{}", r#enum.namespace.as_ref().map(|s|format!(".{s}")).unwrap_or_default(), r#enum.name),
Schema::Fixed(fixed) =>
format!("{}{}", fixed.namespace.as_ref().map(|s|format!(".{s}")).unwrap_or_default(), fixed.name),
_ =>
unreachable!("Should not be called for {:?}", schema),
}
}
5 changes: 4 additions & 1 deletion src/schema/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
//! Contains structs defining Avro's logical types

use serde_json::Value;

mod de;
mod se;

Expand Down Expand Up @@ -58,7 +61,7 @@ pub struct Field {
/// Its Schema
pub schema: Schema,
/// Its default value
pub default: Option<Schema>,
pub default: Option<Value>,
/// Its optional order
pub order: Option<Order>,
/// Its aliases
Expand Down
Loading