diff --git a/src/schema/de.rs b/src/schema/de.rs index 356c842..a751ffd 100644 --- a/src/schema/de.rs +++ b/src/schema/de.rs @@ -1,13 +1,22 @@ use std::{collections::HashMap, fmt}; use serde::{ - de::{MapAccess, SeqAccess, Visitor}, + de::{DeserializeSeed, MapAccess, SeqAccess, Visitor}, Deserialize, Deserializer, }; use serde_json::Value; use super::*; +/// Should be used instead of serde_json::from_value to preserve +/// information about already parsed types. +fn deserialize_schema_from_value( + value: Value, + visited: &mut HashMap +) -> Result { + SchemaDeserializer { visited_types: visited }.deserialize(value) +} + fn to_primitive(v: &str) -> Option { use Schema::*; Some(match v { @@ -86,6 +95,10 @@ fn remove_vec_string( } } +fn remove_json_value(data: &mut HashMap, key: &str,) -> Option { + data.remove(key) +} + fn to_enum(data: &mut HashMap) -> Result { Ok(Schema::Enum(Enum { name: remove_string(data, "name")? @@ -98,59 +111,10 @@ fn to_enum(data: &mut HashMap) -> Result(data: &mut HashMap) -> Result { - let item = data - .remove("values") - .ok_or_else(|| serde::de::Error::custom("values is required in a map"))?; - let schema: Schema = serde_json::from_value(item).map_err(serde::de::Error::custom)?; - Ok(Schema::Map(Box::new(schema))) -} - -fn to_schema( - data: &mut HashMap, - key: &str, -) -> Result, E> { - let schema = data.remove(key); - schema - .map(|schema| serde_json::from_value(schema).map_err(serde::de::Error::custom)) - .transpose() -} - -fn to_array(data: &mut HashMap) -> Result { - let schema = - to_schema(data, "items")?.ok_or_else(|| E::custom("items is required in an array"))?; - Ok(Schema::Array(Box::new(schema))) -} - -fn to_field(data: Value) -> Result { - serde_json::from_value(data).map_err(E::custom) -} - -fn to_vec_fields( - data: &mut HashMap, - key: &str, -) -> Result, E> { - match data.remove(key) { - Some(s) => { - if let Value::Array(x) = s { - x.into_iter().map(to_field).collect() - } else { - Err(E::custom(format!("{} must be a string", key))) - } - } - None => Ok(vec![]), - } -} - -fn to_record(data: &mut HashMap) -> Result { - Ok(Schema::Record(Record { - name: remove_string(data, "name")? - .ok_or_else(|| serde::de::Error::custom("name is required in enum"))?, - namespace: remove_string(data, "namespace")?, - aliases: remove_vec_string(data, "aliases")?, - doc: remove_string(data, "doc")?, - fields: to_vec_fields(data, "fields")?, - })) +fn to_field(data: Value, visited: &mut HashMap) -> Result { + FieldDeserializer { + visited_types: visited, + }.deserialize(data).map_err(E::custom) } fn to_fixed(data: &mut HashMap) -> Result { @@ -199,9 +163,68 @@ fn to_order( .transpose() } -struct SchemaVisitor {} +struct SchemaVisitor<'a> { + visited_types: &'a mut HashMap, +} + +impl<'a>SchemaVisitor<'a> { + fn to_map(&mut self, data: &mut HashMap) -> Result { + let item = data + .remove("values") + .ok_or_else(|| serde::de::Error::custom("values is required in a map"))?; + let schema: Schema = deserialize_schema_from_value(item, self.visited_types).map_err(serde::de::Error::custom)?; + Ok(Schema::Map(Box::new(schema))) + } + + fn to_schema( + &mut self, + data: &mut HashMap, + key: &str, + ) -> Result, E> { + let schema = data.remove(key); + schema + .map(|schema| deserialize_schema_from_value(schema, self.visited_types).map_err(serde::de::Error::custom)) + .transpose() + } -impl<'de> Visitor<'de> for SchemaVisitor { + fn to_array(&mut self, data: &mut HashMap) -> Result { + let schema = + self.to_schema(data, "items")?.ok_or_else(|| E::custom("items is required in an array"))?; + Ok(Schema::Array(Box::new(schema))) + } + + fn to_vec_fields( + &mut self, + data: &mut HashMap, + key: &str, + ) -> Result, E> { + match data.remove(key) { + Some(s) => { + if let Value::Array(x) = s { + x.into_iter() + .map(|a| to_field(a.clone(), self.visited_types)) + .collect() + } else { + Err(E::custom(format!("{} must be a string", key))) + } + } + None => Ok(vec![]), + } + } + + fn to_record(&mut self, data: &mut HashMap) -> Result { + Ok(Schema::Record(Record { + name: remove_string(data, "name")? + .ok_or_else(|| serde::de::Error::custom("name is required in enum"))?, + namespace: remove_string(data, "namespace")?, + aliases: remove_vec_string(data, "aliases")?, + doc: remove_string(data, "doc")?, + fields: self.to_vec_fields(data, "fields")?, + })) + } +} + +impl<'de, 'a> Visitor<'de> for SchemaVisitor<'a> { type Value = Schema; // Format a message stating what data this Visitor expects to receive. @@ -213,7 +236,7 @@ impl<'de> Visitor<'de> for SchemaVisitor { where D: Deserializer<'de>, { - deserializer.deserialize_any(SchemaVisitor {}) + deserializer.deserialize_any(self) } fn visit_none(self) -> Result @@ -227,8 +250,23 @@ impl<'de> Visitor<'de> for SchemaVisitor { where E: serde::de::Error, { - to_primitive(v) - .ok_or_else(|| serde::de::Error::custom("string must be a valid primitive Schema")) + let r#type = to_primitive(v); + if let Some(t) = r#type { + Ok(t) + } else { + // Previously defined types are referred using just a string + // containing the name of that previously defined type. + // That's why when the type cannot be parsed as primitive type, + // need to also check names of previously parsed types. + if let Some(t) = self.visited_types.get(&v.to_string()) { + Ok(t.clone()) + } else { + for vt in self.visited_types.keys() { + println!("Visited type: {vt}"); + } + Err(serde::de::Error::custom(format!("string '{v}' must be a valid primitive Schema"))) + } + } } fn visit_seq(self, mut seq: A) -> Result @@ -237,13 +275,13 @@ impl<'de> Visitor<'de> for SchemaVisitor { { let mut vec = Vec::with_capacity(seq.size_hint().unwrap_or(0)); while let Some(item) = seq.next_element::()? { - let schema: Schema = serde_json::from_value(item).map_err(serde::de::Error::custom)?; + let schema: Schema = deserialize_schema_from_value(item, self.visited_types).map_err(serde::de::Error::custom)?; vec.push(schema) } Ok(Schema::Union(vec)) } - fn visit_map(self, mut access: M) -> Result + fn visit_map(mut self, mut access: M) -> Result where M: MapAccess<'de>, { @@ -304,11 +342,32 @@ impl<'de> Visitor<'de> for SchemaVisitor { }) } else { match type_.as_ref() { - "enum" => to_enum(&mut map), - "map" => to_map(&mut map), - "array" => to_array(&mut map), - "record" => to_record(&mut map), - "fixed" => to_fixed(&mut map), + "enum" => { + let custom_type = to_enum(&mut map); + if let Ok(t) = custom_type.as_ref() { + let type_name = name_of_complex_type(t); + self.visited_types.insert(type_name, t.clone()); + } + custom_type + }, + "map" => self.to_map(&mut map), + "array" => self.to_array(&mut map), + "record" => { + let custom_type = self.to_record(&mut map); + if let Ok(t) = custom_type.as_ref() { + let type_name = name_of_complex_type(t); + self.visited_types.insert(type_name, t.clone()); + } + custom_type + }, + "fixed" => { + let custom_type = to_fixed(&mut map); + if let Ok(t) = custom_type.as_ref() { + let type_name = name_of_complex_type(t); + self.visited_types.insert(type_name, t.clone()); + } + custom_type + }, other => todo!("{}", other), } } @@ -320,13 +379,37 @@ impl<'de> Deserialize<'de> for Schema { where D: Deserializer<'de>, { - deserializer.deserialize_option(SchemaVisitor {}) + let mut visited = HashMap::new(); + deserializer.deserialize_option(SchemaVisitor { visited_types: &mut visited}) } } -struct FieldVisitor {} +/// AVRO schema allows so called ref types: +/// after a complex type (record, enum, fixed) is defined in a schema, +/// subsequent fields that have the same type can refer this +/// type by just specifying the type name. +/// +/// To keep information about previously parsed types, +/// we have to use a stateful deserializer based on DeserializeSeed. +struct SchemaDeserializer<'a> { + visited_types: &'a mut HashMap, +} + +impl<'de, 'a> DeserializeSeed<'de> for SchemaDeserializer<'a> { + type Value = Schema; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de> { + deserializer.deserialize_option(SchemaVisitor { visited_types: self.visited_types }) + } +} -impl<'de> Visitor<'de> for FieldVisitor { +struct FieldVisitor<'a> { + visited_types: &'a mut HashMap, +} + +impl<'de, 'a> Visitor<'de> for FieldVisitor<'a> { type Value = Field; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { @@ -345,24 +428,46 @@ impl<'de> Visitor<'de> for FieldVisitor { map.insert(key, value); } + let parsed_schema = + SchemaVisitor {visited_types: self.visited_types} + .to_schema(&mut map, "type"); + Ok(Field { name: remove_string(&mut map, "name")? .ok_or_else(|| serde::de::Error::custom("name is required in enum"))?, doc: remove_string(&mut map, "doc")?, - schema: to_schema(&mut map, "type")? + schema: parsed_schema? .ok_or_else(|| serde::de::Error::custom("type is required in Field"))?, - default: to_schema(&mut map, "default")?, + default: remove_json_value(&mut map, "default"), order: to_order(&mut map, "order")?, aliases: remove_vec_string(&mut map, "aliases")?, }) } } -impl<'de> Deserialize<'de> for Field { - fn deserialize(deserializer: D) -> Result +struct FieldDeserializer<'a> { + visited_types: &'a mut HashMap, +} + +impl<'de, 'a> DeserializeSeed<'de> for FieldDeserializer<'a> { + type Value = Field; + + fn deserialize(self, deserializer: D) -> Result where - D: Deserializer<'de>, - { - deserializer.deserialize_map(FieldVisitor {}) + D: Deserializer<'de> { + deserializer.deserialize_map(FieldVisitor { visited_types: self.visited_types}) + } +} + +fn name_of_complex_type(schema: &Schema) -> String { + match schema { + Schema::Record(record) => + format!("{}{}", record.namespace.as_ref().map(|s|format!(".{s}")).unwrap_or_default(), record.name), + Schema::Enum(r#enum) => + format!("{}{}", r#enum.namespace.as_ref().map(|s|format!(".{s}")).unwrap_or_default(), r#enum.name), + Schema::Fixed(fixed) => + format!("{}{}", fixed.namespace.as_ref().map(|s|format!(".{s}")).unwrap_or_default(), fixed.name), + _ => + unreachable!("Should not be called for {:?}", schema), } } diff --git a/src/schema/mod.rs b/src/schema/mod.rs index c7c4a3f..a20cb56 100644 --- a/src/schema/mod.rs +++ b/src/schema/mod.rs @@ -1,4 +1,7 @@ //! Contains structs defining Avro's logical types + +use serde_json::Value; + mod de; mod se; @@ -58,7 +61,7 @@ pub struct Field { /// Its Schema pub schema: Schema, /// Its default value - pub default: Option, + pub default: Option, /// Its optional order pub order: Option, /// Its aliases diff --git a/tests/it/main.rs b/tests/it/main.rs index 099bdcc..6775ccb 100644 --- a/tests/it/main.rs +++ b/tests/it/main.rs @@ -145,6 +145,70 @@ fn cases() -> Vec<(&'static str, Schema)> { ], }), ), + ( // Verify default is parsed correctly + r#"{ + "type":"record", + "name":"MyRecord", + "namespace":"com.company.department.schema.avro.sch", + "fields":[ + { + "name":"schemaVersion", + "type":"string", + "default":"1.6.8" + } + ] + }"#, + Record(avro_schema::schema::Record { + name: "MyRecord".to_string(), + namespace: Some("com.company.department.schema.avro.sch".to_string()), + doc: None, + aliases: vec![], + fields: vec![ + Field { + name: "schemaVersion".to_string(), + schema:Schema::String(None), + doc: None, + default: serde_json::to_value("1.6.8").ok(), + order: None, + aliases: vec![], + }, + ], + }), + ), + ( // Verify referred type is supported + r#"{ + "type":"record", + "name":"MyRecord", + "namespace":"com.company.department.schema.avro.sch", + "fields":[ + { + "name": "field1", + "type": { + "name": "MyType", + "type": "record" + } + }, + { + "name": "field2", + "type": "MyType" + } + ] + }"#, + Record(avro_schema::schema::Record { + name: "MyRecord".to_string(), + namespace: Some("com.company.department.schema.avro.sch".to_string()), + doc: None, + aliases: vec![], + fields: vec![ + Field::new("field1", avro_schema::schema::Record { + name: "MyType".to_string(), namespace: None, doc: None, aliases: vec![], fields: vec![] + }.into()), + Field::new("field2", avro_schema::schema::Record { + name: "MyType".to_string(), namespace: None, doc: None, aliases: vec![], fields: vec![] + }.into()), + ], + }), + ), ] }