From cff8bf8a82bdba8da7848a2372181af3327c6d83 Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Tue, 30 Sep 2025 11:54:10 +0100 Subject: [PATCH 1/4] some work --- arrow-json/Cargo.toml | 2 +- arrow-json/src/reader/mod.rs | 2 +- arrow-json/src/reader/serializer.rs | 6 +- arrow-json/src/reader/tape.rs | 2 +- arrow-json/src/writer/encoder.rs | 2 +- arrow-schema/Cargo.toml | 12 +- .../extension/canonical/fixed_shape_tensor.rs | 142 +++++++++++++++++- arrow-schema/src/extension/canonical/json.rs | 77 +++++++++- .../src/extension/canonical/opaque.rs | 129 +++++++++++++++- .../canonical/variable_shape_tensor.rs | 142 +++++++++++++++++- 10 files changed, 494 insertions(+), 22 deletions(-) diff --git a/arrow-json/Cargo.toml b/arrow-json/Cargo.toml index 291bfb1906c9..191856f3bb47 100644 --- a/arrow-json/Cargo.toml +++ b/arrow-json/Cargo.toml @@ -44,7 +44,7 @@ arrow-schema = { workspace = true } half = { version = "2.1", default-features = false } indexmap = { version = "2.0", default-features = false, features = ["std"] } num-traits = { version = "0.2.19", default-features = false, features = ["std"] } -serde = { version = "1.0", default-features = false } +serde_core = { version = "1.0", default-features = false } serde_json = { version = "1.0", default-features = false, features = ["std"] } chrono = { workspace = true } lexical-core = { version = "1.0", default-features = false} diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index e4658f865314..9e8e76db7151 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -138,7 +138,7 @@ use std::io::BufRead; use std::sync::Arc; use chrono::Utc; -use serde::Serialize; +use serde_core::Serialize; use arrow_array::timezone::Tz; use arrow_array::types::*; diff --git a/arrow-json/src/reader/serializer.rs b/arrow-json/src/reader/serializer.rs index 95068af67833..5d004fbb5c9b 100644 --- a/arrow-json/src/reader/serializer.rs +++ b/arrow-json/src/reader/serializer.rs @@ -17,10 +17,10 @@ use crate::reader::tape::TapeElement; use lexical_core::FormattedSize; -use serde::ser::{ +use serde_core::ser::{ Impossible, SerializeMap, SerializeSeq, SerializeStruct, SerializeTuple, SerializeTupleStruct, }; -use serde::{Serialize, Serializer}; +use serde_core::{Serialize, Serializer}; #[derive(Debug)] pub struct SerializerError(String); @@ -33,7 +33,7 @@ impl std::fmt::Display for SerializerError { } } -impl serde::ser::Error for SerializerError { +impl serde_core::ser::Error for SerializerError { fn custom(msg: T) -> Self where T: std::fmt::Display, diff --git a/arrow-json/src/reader/tape.rs b/arrow-json/src/reader/tape.rs index e3e42ae1cc32..89ee3f778765 100644 --- a/arrow-json/src/reader/tape.rs +++ b/arrow-json/src/reader/tape.rs @@ -18,7 +18,7 @@ use crate::reader::serializer::TapeSerializer; use arrow_schema::ArrowError; use memchr::memchr2; -use serde::Serialize; +use serde_core::Serialize; use std::fmt::Write; /// We decode JSON to a flattened tape representation, diff --git a/arrow-json/src/writer/encoder.rs b/arrow-json/src/writer/encoder.rs index c960da3e0757..b562249fc527 100644 --- a/arrow-json/src/writer/encoder.rs +++ b/arrow-json/src/writer/encoder.rs @@ -26,7 +26,7 @@ use arrow_cast::display::{ArrayFormatter, FormatOptions}; use arrow_schema::{ArrowError, DataType, FieldRef}; use half::f16; use lexical_core::FormattedSize; -use serde::Serializer; +use serde_core::Serializer; /// Configuration options for the JSON encoder. #[derive(Debug, Clone, Default)] diff --git a/arrow-schema/Cargo.toml b/arrow-schema/Cargo.toml index d71d55496b1b..7c77d5279b42 100644 --- a/arrow-schema/Cargo.toml +++ b/arrow-schema/Cargo.toml @@ -33,8 +33,7 @@ name = "arrow_schema" bench = false [dependencies] -serde = { version = "1.0", default-features = false, features = [ - "derive", +serde_core = { version = "1.0", default-features = false, features = [ "std", "rc", ], optional = true } @@ -42,16 +41,19 @@ bitflags = { version = "2.0.0", default-features = false, optional = true } serde_json = { version = "1.0", optional = true } [features] -canonical_extension_types = ["dep:serde", "dep:serde_json"] +canonical_extension_types = ["dep:serde_core", "dep:serde_json"] # Enable ffi support ffi = ["bitflags"] -serde = ["dep:serde"] +serde = ["dep:serde_core"] [package.metadata.docs.rs] all-features = true [dev-dependencies] -bincode = { version = "2.0.1", default-features = false, features = ["std", "serde"] } +bincode = { version = "2.0.1", default-features = false, features = [ + "std", + "serde", +] } criterion = { version = "0.5", default-features = false } insta = "1.43.1" diff --git a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs index 94258123aae7..6cd204321b84 100644 --- a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs +++ b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs @@ -19,7 +19,7 @@ //! //! -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::{ArrowError, DataType, extension::ExtensionType}; @@ -129,7 +129,7 @@ impl FixedShapeTensor { } /// Extension type metadata for [`FixedShapeTensor`]. -#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq)] pub struct FixedShapeTensorMetadata { /// The physical shape of the contained tensors. shape: Vec, @@ -141,6 +141,144 @@ pub struct FixedShapeTensorMetadata { permutations: Option>, } +impl Serialize for FixedShapeTensorMetadata { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + use serde::ser::SerializeStruct; + let mut state = serializer.serialize_struct("FixedShapeTensorMetadata", 3)?; + state.serialize_field("shape", &self.shape)?; + state.serialize_field("dim_names", &self.dim_names)?; + state.serialize_field("permutations", &self.permutations)?; + state.end() + } +} + +impl<'de> Deserialize<'de> for FixedShapeTensorMetadata { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + use serde::de::{self, MapAccess, Visitor}; + use std::fmt; + + #[derive(Debug)] + enum Field { + Shape, + DimNames, + Permutations, + } + + impl<'de> Deserialize<'de> for Field { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct FieldVisitor; + + impl<'de> Visitor<'de> for FieldVisitor { + type Value = Field; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("`shape`, `dim_names`, or `permutations`") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match value { + "shape" => Ok(Field::Shape), + "dim_names" => Ok(Field::DimNames), + "permutations" => Ok(Field::Permutations), + _ => Err(de::Error::unknown_field( + value, + &["shape", "dim_names", "permutations"], + )), + } + } + } + + deserializer.deserialize_identifier(FieldVisitor) + } + } + + struct FixedShapeTensorMetadataVisitor; + + impl<'de> Visitor<'de> for FixedShapeTensorMetadataVisitor { + type Value = FixedShapeTensorMetadata; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct FixedShapeTensorMetadata") + } + + fn visit_seq(self, mut seq: V) -> Result + where + V: de::SeqAccess<'de>, + { + let shape = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + let dim_names = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + let permutations = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(2, &self))?; + Ok(FixedShapeTensorMetadata { + shape, + dim_names, + permutations, + }) + } + + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + let mut shape = None; + let mut dim_names = None; + let mut permutations = None; + + while let Some(key) = map.next_key()? { + match key { + Field::Shape => { + if shape.is_some() { + return Err(de::Error::duplicate_field("shape")); + } + shape = Some(map.next_value()?); + } + Field::DimNames => { + if dim_names.is_some() { + return Err(de::Error::duplicate_field("dim_names")); + } + dim_names = Some(map.next_value()?); + } + Field::Permutations => { + if permutations.is_some() { + return Err(de::Error::duplicate_field("permutations")); + } + permutations = Some(map.next_value()?); + } + } + } + + let shape = shape.ok_or_else(|| de::Error::missing_field("shape"))?; + + Ok(FixedShapeTensorMetadata { + shape, + dim_names, + permutations, + }) + } + } + + const FIELDS: &[&str] = &["shape", "dim_names", "permutations"]; + deserializer.deserialize_struct("FixedShapeTensorMetadata", FIELDS, FixedShapeTensorMetadataVisitor) + } +} + impl FixedShapeTensorMetadata { /// Returns metadata for a fixed shape tensor extension type. /// diff --git a/arrow-schema/src/extension/canonical/json.rs b/arrow-schema/src/extension/canonical/json.rs index 366094510439..01b94980b2db 100644 --- a/arrow-schema/src/extension/canonical/json.rs +++ b/arrow-schema/src/extension/canonical/json.rs @@ -19,7 +19,10 @@ //! //! -use serde::{Deserialize, Serialize}; +use serde_core::de::{self, MapAccess, Visitor}; +use serde_core::ser::SerializeStruct; +use serde_core::{Deserialize, Deserializer, Serialize, Serializer}; +use std::fmt; use crate::{ArrowError, DataType, extension::ExtensionType}; @@ -42,10 +45,78 @@ use crate::{ArrowError, DataType, extension::ExtensionType}; pub struct Json(JsonMetadata); /// Empty object -#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize)] -#[serde(deny_unknown_fields)] +#[derive(Debug, Clone, Copy, PartialEq)] struct Empty {} +impl Serialize for Empty { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let state = serializer.serialize_struct("Empty", 0)?; + state.end() + } +} + +static EMPTY_FIELDS: &[&str] = &[]; + +impl<'de> Deserialize<'de> for Empty { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct EmptyVisitor; + + impl<'de> Visitor<'de> for EmptyVisitor { + type Value = Empty; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct Empty") + } + + fn visit_seq(self, mut _seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + Ok(Empty {}) + } + + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + if let Some(key) = map.next_key::()? { + return Err(de::Error::unknown_field(&key, EMPTY_FIELDS)); + } + Ok(Empty {}) + } + + fn visit_u64(self, _v: u64) -> Result + where + E: de::Error, + { + Err(de::Error::unknown_field("", EMPTY_FIELDS)) + } + + fn visit_str(self, _v: &str) -> Result + where + E: de::Error, + { + Err(de::Error::unknown_field("", EMPTY_FIELDS)) + } + + fn visit_bytes(self, _v: &[u8]) -> Result + where + E: de::Error, + { + Err(de::Error::unknown_field("", EMPTY_FIELDS)) + } + } + + deserializer.deserialize_struct("Empty", EMPTY_FIELDS, EmptyVisitor) + } +} + /// Extension type metadata for [`Json`]. #[derive(Debug, Default, Clone, PartialEq)] pub struct JsonMetadata(Option); diff --git a/arrow-schema/src/extension/canonical/opaque.rs b/arrow-schema/src/extension/canonical/opaque.rs index 5aa064e6d386..cd12e615bdb5 100644 --- a/arrow-schema/src/extension/canonical/opaque.rs +++ b/arrow-schema/src/extension/canonical/opaque.rs @@ -19,7 +19,10 @@ //! //! -use serde::{Deserialize, Serialize}; +use serde_core::{ + Deserialize, Deserializer, Serialize, Serializer, + de::{MapAccess, Visitor}, +}; use crate::{ArrowError, DataType, extension::ExtensionType}; @@ -61,7 +64,7 @@ impl From for Opaque { } /// Extension type metadata for [`Opaque`]. -#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq)] pub struct OpaqueMetadata { /// Name of the unknown type in the external system. type_name: String, @@ -70,6 +73,128 @@ pub struct OpaqueMetadata { vendor_name: String, } +impl Serialize for OpaqueMetadata { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut state = serializer.serialize_struct("OpaqueMetadata", 2)?; + state.serialize_field("type_name", &self.type_name)?; + state.serialize_field("vendor_name", &self.vendor_name)?; + state.end() + } +} + +impl<'de> Deserialize<'de> for OpaqueMetadata { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Debug)] + enum Field { + TypeName, + VendorName, + } + + impl<'de> Deserialize<'de> for Field { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct FieldVisitor; + + impl<'de> Visitor<'de> for FieldVisitor { + type Value = Field; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("`type_name` or `vendor_name`") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde_core::de::Error, + { + match value { + "type_name" => Ok(Field::TypeName), + "vendor_name" => Ok(Field::VendorName), + _ => Err(serde_core::de::Error::unknown_field( + value, + &["type_name", "vendor_name"], + )), + } + } + } + + deserializer.deserialize_identifier(FieldVisitor) + } + } + + struct OpaqueMetadataVisitor; + + impl<'de> Visitor<'de> for OpaqueMetadataVisitor { + type Value = OpaqueMetadata; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("struct OpaqueMetadata") + } + + fn visit_seq(self, mut seq: V) -> Result + where + V: serde_core::de::SeqAccess<'de>, + { + let type_name = seq + .next_element()? + .ok_or_else(|| serde_core::de::Error::invalid_length(0, &self))?; + let vendor_name = seq + .next_element()? + .ok_or_else(|| serde_core::de::Error::invalid_length(1, &self))?; + Ok(OpaqueMetadata { + type_name, + vendor_name, + }) + } + + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + let mut type_name = None; + let mut vendor_name = None; + + while let Some(key) = map.next_key()? { + match key { + Field::TypeName => { + if type_name.is_some() { + return Err(serde_core::de::Error::duplicate_field("type_name")); + } + type_name = Some(map.next_value()?); + } + Field::VendorName => { + if vendor_name.is_some() { + return Err(serde_core::de::Error::duplicate_field("vendor_name")); + } + vendor_name = Some(map.next_value()?); + } + } + } + + let type_name = + type_name.ok_or_else(|| serde_core::de::Error::missing_field("type_name"))?; + let vendor_name = vendor_name + .ok_or_else(|| serde_core::de::Error::missing_field("vendor_name"))?; + + Ok(OpaqueMetadata { + type_name, + vendor_name, + }) + } + } + + const FIELDS: &[&str] = &["type_name", "vendor_name"]; + deserializer.deserialize_struct("OpaqueMetadata", FIELDS, OpaqueMetadataVisitor) + } +} + impl OpaqueMetadata { /// Returns a new `OpaqueMetadata`. pub fn new(type_name: impl Into, vendor_name: impl Into) -> Self { diff --git a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs index 45fde67ee791..1c9f97c02ae7 100644 --- a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs +++ b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs @@ -19,7 +19,7 @@ //! //! -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::{ArrowError, DataType, Field, extension::ExtensionType}; @@ -140,7 +140,7 @@ impl VariableShapeTensor { } /// Extension type metadata for [`VariableShapeTensor`]. -#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq)] pub struct VariableShapeTensorMetadata { /// Explicit names to tensor dimensions. dim_names: Option>, @@ -148,11 +148,147 @@ pub struct VariableShapeTensorMetadata { /// Indices of the desired ordering of the original dimensions. permutations: Option>, - /// Sizes of individual tensor’s dimensions which are guaranteed to stay + /// Sizes of individual tensor's dimensions which are guaranteed to stay /// constant in uniform dimensions and can vary in non-uniform dimensions. uniform_shape: Option>>, } +impl Serialize for VariableShapeTensorMetadata { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + use serde::ser::SerializeStruct; + let mut state = serializer.serialize_struct("VariableShapeTensorMetadata", 3)?; + state.serialize_field("dim_names", &self.dim_names)?; + state.serialize_field("permutations", &self.permutations)?; + state.serialize_field("uniform_shape", &self.uniform_shape)?; + state.end() + } +} + +impl<'de> Deserialize<'de> for VariableShapeTensorMetadata { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + use serde::de::{self, MapAccess, Visitor}; + use std::fmt; + + #[derive(Debug)] + enum Field { + DimNames, + Permutations, + UniformShape, + } + + impl<'de> Deserialize<'de> for Field { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct FieldVisitor; + + impl<'de> Visitor<'de> for FieldVisitor { + type Value = Field; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("`dim_names`, `permutations`, or `uniform_shape`") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match value { + "dim_names" => Ok(Field::DimNames), + "permutations" => Ok(Field::Permutations), + "uniform_shape" => Ok(Field::UniformShape), + _ => Err(de::Error::unknown_field( + value, + &["dim_names", "permutations", "uniform_shape"], + )), + } + } + } + + deserializer.deserialize_identifier(FieldVisitor) + } + } + + struct VariableShapeTensorMetadataVisitor; + + impl<'de> Visitor<'de> for VariableShapeTensorMetadataVisitor { + type Value = VariableShapeTensorMetadata; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct VariableShapeTensorMetadata") + } + + fn visit_seq(self, mut seq: V) -> Result + where + V: de::SeqAccess<'de>, + { + let dim_names = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + let permutations = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + let uniform_shape = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(2, &self))?; + Ok(VariableShapeTensorMetadata { + dim_names, + permutations, + uniform_shape, + }) + } + + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + let mut dim_names = None; + let mut permutations = None; + let mut uniform_shape = None; + + while let Some(key) = map.next_key()? { + match key { + Field::DimNames => { + if dim_names.is_some() { + return Err(de::Error::duplicate_field("dim_names")); + } + dim_names = Some(map.next_value()?); + } + Field::Permutations => { + if permutations.is_some() { + return Err(de::Error::duplicate_field("permutations")); + } + permutations = Some(map.next_value()?); + } + Field::UniformShape => { + if uniform_shape.is_some() { + return Err(de::Error::duplicate_field("uniform_shape")); + } + uniform_shape = Some(map.next_value()?); + } + } + } + + Ok(VariableShapeTensorMetadata { + dim_names, + permutations, + uniform_shape, + }) + } + } + + const FIELDS: &[&str] = &["dim_names", "permutations", "uniform_shape"]; + deserializer.deserialize_struct("VariableShapeTensorMetadata", FIELDS, VariableShapeTensorMetadataVisitor) + } +} + impl VariableShapeTensorMetadata { /// Returns metadata for a variable shape tensor extension type. /// From bf4ce13e16394bb7c3bffcf6a5ccb806a4dda197 Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Mon, 6 Oct 2025 11:14:18 +0100 Subject: [PATCH 2/4] Followup --- arrow-integration-testing/Cargo.toml | 1 - arrow-schema/Cargo.toml | 5 ++++- .../extension/canonical/fixed_shape_tensor.rs | 20 ++++++++++--------- .../src/extension/canonical/opaque.rs | 1 + .../canonical/variable_shape_tensor.rs | 15 ++++++++------ 5 files changed, 25 insertions(+), 17 deletions(-) diff --git a/arrow-integration-testing/Cargo.toml b/arrow-integration-testing/Cargo.toml index 35eb47b8d681..ae13d32b57a9 100644 --- a/arrow-integration-testing/Cargo.toml +++ b/arrow-integration-testing/Cargo.toml @@ -40,7 +40,6 @@ arrow-integration-test = { path = "../arrow-integration-test", default-features clap = { version = "4", default-features = false, features = ["std", "derive", "help", "error-context", "usage"] } futures = { version = "0.3", default-features = false } prost = { version = "0.14.1", default-features = false } -serde = { version = "1.0", default-features = false, features = ["rc", "derive"] } serde_json = { version = "1.0", default-features = false, features = ["std"] } tokio = { version = "1.0", default-features = false, features = [ "rt-multi-thread"] } tonic = { version = "0.14.1", default-features = false } diff --git a/arrow-schema/Cargo.toml b/arrow-schema/Cargo.toml index 7c77d5279b42..bcc613b862cb 100644 --- a/arrow-schema/Cargo.toml +++ b/arrow-schema/Cargo.toml @@ -37,6 +37,9 @@ serde_core = { version = "1.0", default-features = false, features = [ "std", "rc", ], optional = true } +serde = { version = "1.0", default-features = false, features = [ + "derive", +], optional = true } bitflags = { version = "2.0.0", default-features = false, optional = true } serde_json = { version = "1.0", optional = true } @@ -44,7 +47,7 @@ serde_json = { version = "1.0", optional = true } canonical_extension_types = ["dep:serde_core", "dep:serde_json"] # Enable ffi support ffi = ["bitflags"] -serde = ["dep:serde_core"] +serde = ["dep:serde_core", "dep:serde"] [package.metadata.docs.rs] all-features = true diff --git a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs index 6cd204321b84..63b2f79ba1af 100644 --- a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs +++ b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs @@ -19,7 +19,10 @@ //! //! -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde_core::de::{self, MapAccess, Visitor}; +use serde_core::ser::SerializeStruct; +use serde_core::{Deserialize, Deserializer, Serialize, Serializer}; +use std::fmt; use crate::{ArrowError, DataType, extension::ExtensionType}; @@ -146,7 +149,6 @@ impl Serialize for FixedShapeTensorMetadata { where S: Serializer, { - use serde::ser::SerializeStruct; let mut state = serializer.serialize_struct("FixedShapeTensorMetadata", 3)?; state.serialize_field("shape", &self.shape)?; state.serialize_field("dim_names", &self.dim_names)?; @@ -160,9 +162,6 @@ impl<'de> Deserialize<'de> for FixedShapeTensorMetadata { where D: Deserializer<'de>, { - use serde::de::{self, MapAccess, Visitor}; - use std::fmt; - #[derive(Debug)] enum Field { Shape, @@ -275,7 +274,11 @@ impl<'de> Deserialize<'de> for FixedShapeTensorMetadata { } const FIELDS: &[&str] = &["shape", "dim_names", "permutations"]; - deserializer.deserialize_struct("FixedShapeTensorMetadata", FIELDS, FixedShapeTensorMetadataVisitor) + deserializer.deserialize_struct( + "FixedShapeTensorMetadata", + FIELDS, + FixedShapeTensorMetadataVisitor, + ) } } @@ -515,9 +518,8 @@ mod tests { } #[test] - #[should_panic( - expected = "FixedShapeTensor metadata deserialization failed: missing field `shape`" - )] + #[should_panic(expected = "FixedShapeTensor metadata deserialization failed: \ + unknown field `not-shape`, expected one of `shape`, `dim_names`, `permutations`")] fn invalid_metadata() { let fixed_shape_tensor = FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, None).unwrap(); diff --git a/arrow-schema/src/extension/canonical/opaque.rs b/arrow-schema/src/extension/canonical/opaque.rs index cd12e615bdb5..5cd9c2982c6c 100644 --- a/arrow-schema/src/extension/canonical/opaque.rs +++ b/arrow-schema/src/extension/canonical/opaque.rs @@ -19,6 +19,7 @@ //! //! +use serde_core::ser::SerializeStruct; use serde_core::{ Deserialize, Deserializer, Serialize, Serializer, de::{MapAccess, Visitor}, diff --git a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs index 1c9f97c02ae7..cd1e459a4b6c 100644 --- a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs +++ b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs @@ -19,7 +19,9 @@ //! //! -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde_core::de::{self, MapAccess, Visitor}; +use serde_core::{Deserialize, Deserializer, Serialize, Serializer}; +use std::fmt; use crate::{ArrowError, DataType, Field, extension::ExtensionType}; @@ -158,7 +160,7 @@ impl Serialize for VariableShapeTensorMetadata { where S: Serializer, { - use serde::ser::SerializeStruct; + use serde_core::ser::SerializeStruct; let mut state = serializer.serialize_struct("VariableShapeTensorMetadata", 3)?; state.serialize_field("dim_names", &self.dim_names)?; state.serialize_field("permutations", &self.permutations)?; @@ -172,9 +174,6 @@ impl<'de> Deserialize<'de> for VariableShapeTensorMetadata { where D: Deserializer<'de>, { - use serde::de::{self, MapAccess, Visitor}; - use std::fmt; - #[derive(Debug)] enum Field { DimNames, @@ -285,7 +284,11 @@ impl<'de> Deserialize<'de> for VariableShapeTensorMetadata { } const FIELDS: &[&str] = &["dim_names", "permutations", "uniform_shape"]; - deserializer.deserialize_struct("VariableShapeTensorMetadata", FIELDS, VariableShapeTensorMetadataVisitor) + deserializer.deserialize_struct( + "VariableShapeTensorMetadata", + FIELDS, + VariableShapeTensorMetadataVisitor, + ) } } From 243f29507365aa9b6ed47e64ce1dd2648b850e0e Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Mon, 6 Oct 2025 11:52:02 +0100 Subject: [PATCH 3/4] Fix docs link --- arrow-json/src/reader/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index 9e8e76db7151..c47aa65f81c4 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -613,6 +613,8 @@ impl Decoder { /// ``` /// /// Note: this ignores any batch size setting, and always decodes all rows + /// + /// [serde]: https://docs.rs/serde/latest/serde/ pub fn serialize(&mut self, rows: &[S]) -> Result<(), ArrowError> { self.tape_decoder.serialize(rows) } From 88e1e2b3d07d1a9176e2364fdc8bd265ec5c1b92 Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Mon, 6 Oct 2025 16:17:58 +0100 Subject: [PATCH 4/4] Move serde support structs outside of serde trait implementations --- .../extension/canonical/fixed_shape_tensor.rs | 205 +++++++++--------- arrow-schema/src/extension/canonical/json.rs | 96 ++++---- .../src/extension/canonical/opaque.rs | 187 ++++++++-------- .../canonical/variable_shape_tensor.rs | 201 +++++++++-------- 4 files changed, 345 insertions(+), 344 deletions(-) diff --git a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs index 63b2f79ba1af..b6bd1c1223f4 100644 --- a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs +++ b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs @@ -157,126 +157,125 @@ impl Serialize for FixedShapeTensorMetadata { } } -impl<'de> Deserialize<'de> for FixedShapeTensorMetadata { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - #[derive(Debug)] - enum Field { - Shape, - DimNames, - Permutations, - } - - impl<'de> Deserialize<'de> for Field { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct FieldVisitor; +#[derive(Debug)] +enum MetadataField { + Shape, + DimNames, + Permutations, +} - impl<'de> Visitor<'de> for FieldVisitor { - type Value = Field; +struct MetadataFieldVisitor; - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("`shape`, `dim_names`, or `permutations`") - } +impl<'de> Visitor<'de> for MetadataFieldVisitor { + type Value = MetadataField; - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - match value { - "shape" => Ok(Field::Shape), - "dim_names" => Ok(Field::DimNames), - "permutations" => Ok(Field::Permutations), - _ => Err(de::Error::unknown_field( - value, - &["shape", "dim_names", "permutations"], - )), - } - } - } + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("`shape`, `dim_names`, or `permutations`") + } - deserializer.deserialize_identifier(FieldVisitor) - } + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match value { + "shape" => Ok(MetadataField::Shape), + "dim_names" => Ok(MetadataField::DimNames), + "permutations" => Ok(MetadataField::Permutations), + _ => Err(de::Error::unknown_field( + value, + &["shape", "dim_names", "permutations"], + )), } + } +} + +impl<'de> Deserialize<'de> for MetadataField { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_identifier(MetadataFieldVisitor) + } +} - struct FixedShapeTensorMetadataVisitor; +struct FixedShapeTensorMetadataVisitor; - impl<'de> Visitor<'de> for FixedShapeTensorMetadataVisitor { - type Value = FixedShapeTensorMetadata; +impl<'de> Visitor<'de> for FixedShapeTensorMetadataVisitor { + type Value = FixedShapeTensorMetadata; - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("struct FixedShapeTensorMetadata") - } + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct FixedShapeTensorMetadata") + } - fn visit_seq(self, mut seq: V) -> Result - where - V: de::SeqAccess<'de>, - { - let shape = seq - .next_element()? - .ok_or_else(|| de::Error::invalid_length(0, &self))?; - let dim_names = seq - .next_element()? - .ok_or_else(|| de::Error::invalid_length(1, &self))?; - let permutations = seq - .next_element()? - .ok_or_else(|| de::Error::invalid_length(2, &self))?; - Ok(FixedShapeTensorMetadata { - shape, - dim_names, - permutations, - }) - } + fn visit_seq(self, mut seq: V) -> Result + where + V: de::SeqAccess<'de>, + { + let shape = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + let dim_names = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + let permutations = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(2, &self))?; + Ok(FixedShapeTensorMetadata { + shape, + dim_names, + permutations, + }) + } - fn visit_map(self, mut map: V) -> Result - where - V: MapAccess<'de>, - { - let mut shape = None; - let mut dim_names = None; - let mut permutations = None; - - while let Some(key) = map.next_key()? { - match key { - Field::Shape => { - if shape.is_some() { - return Err(de::Error::duplicate_field("shape")); - } - shape = Some(map.next_value()?); - } - Field::DimNames => { - if dim_names.is_some() { - return Err(de::Error::duplicate_field("dim_names")); - } - dim_names = Some(map.next_value()?); - } - Field::Permutations => { - if permutations.is_some() { - return Err(de::Error::duplicate_field("permutations")); - } - permutations = Some(map.next_value()?); - } + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + let mut shape = None; + let mut dim_names = None; + let mut permutations = None; + + while let Some(key) = map.next_key()? { + match key { + MetadataField::Shape => { + if shape.is_some() { + return Err(de::Error::duplicate_field("shape")); } + shape = Some(map.next_value()?); + } + MetadataField::DimNames => { + if dim_names.is_some() { + return Err(de::Error::duplicate_field("dim_names")); + } + dim_names = Some(map.next_value()?); + } + MetadataField::Permutations => { + if permutations.is_some() { + return Err(de::Error::duplicate_field("permutations")); + } + permutations = Some(map.next_value()?); } - - let shape = shape.ok_or_else(|| de::Error::missing_field("shape"))?; - - Ok(FixedShapeTensorMetadata { - shape, - dim_names, - permutations, - }) } } - const FIELDS: &[&str] = &["shape", "dim_names", "permutations"]; + let shape = shape.ok_or_else(|| de::Error::missing_field("shape"))?; + + Ok(FixedShapeTensorMetadata { + shape, + dim_names, + permutations, + }) + } +} + +impl<'de> Deserialize<'de> for FixedShapeTensorMetadata { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { deserializer.deserialize_struct( "FixedShapeTensorMetadata", - FIELDS, + &["shape", "dim_names", "permutations"], FixedShapeTensorMetadataVisitor, ) } diff --git a/arrow-schema/src/extension/canonical/json.rs b/arrow-schema/src/extension/canonical/json.rs index 01b94980b2db..297a2d99aa04 100644 --- a/arrow-schema/src/extension/canonical/json.rs +++ b/arrow-schema/src/extension/canonical/json.rs @@ -58,6 +58,54 @@ impl Serialize for Empty { } } +struct EmptyVisitor; + +impl<'de> Visitor<'de> for EmptyVisitor { + type Value = Empty; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct Empty") + } + + fn visit_seq(self, mut _seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + Ok(Empty {}) + } + + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + if let Some(key) = map.next_key::()? { + return Err(de::Error::unknown_field(&key, EMPTY_FIELDS)); + } + Ok(Empty {}) + } + + fn visit_u64(self, _v: u64) -> Result + where + E: de::Error, + { + Err(de::Error::unknown_field("", EMPTY_FIELDS)) + } + + fn visit_str(self, _v: &str) -> Result + where + E: de::Error, + { + Err(de::Error::unknown_field("", EMPTY_FIELDS)) + } + + fn visit_bytes(self, _v: &[u8]) -> Result + where + E: de::Error, + { + Err(de::Error::unknown_field("", EMPTY_FIELDS)) + } +} + static EMPTY_FIELDS: &[&str] = &[]; impl<'de> Deserialize<'de> for Empty { @@ -65,54 +113,6 @@ impl<'de> Deserialize<'de> for Empty { where D: Deserializer<'de>, { - struct EmptyVisitor; - - impl<'de> Visitor<'de> for EmptyVisitor { - type Value = Empty; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("struct Empty") - } - - fn visit_seq(self, mut _seq: A) -> Result - where - A: de::SeqAccess<'de>, - { - Ok(Empty {}) - } - - fn visit_map(self, mut map: V) -> Result - where - V: MapAccess<'de>, - { - if let Some(key) = map.next_key::()? { - return Err(de::Error::unknown_field(&key, EMPTY_FIELDS)); - } - Ok(Empty {}) - } - - fn visit_u64(self, _v: u64) -> Result - where - E: de::Error, - { - Err(de::Error::unknown_field("", EMPTY_FIELDS)) - } - - fn visit_str(self, _v: &str) -> Result - where - E: de::Error, - { - Err(de::Error::unknown_field("", EMPTY_FIELDS)) - } - - fn visit_bytes(self, _v: &[u8]) -> Result - where - E: de::Error, - { - Err(de::Error::unknown_field("", EMPTY_FIELDS)) - } - } - deserializer.deserialize_struct("Empty", EMPTY_FIELDS, EmptyVisitor) } } diff --git a/arrow-schema/src/extension/canonical/opaque.rs b/arrow-schema/src/extension/canonical/opaque.rs index 5cd9c2982c6c..fceae8d3711d 100644 --- a/arrow-schema/src/extension/canonical/opaque.rs +++ b/arrow-schema/src/extension/canonical/opaque.rs @@ -86,113 +86,116 @@ impl Serialize for OpaqueMetadata { } } -impl<'de> Deserialize<'de> for OpaqueMetadata { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - #[derive(Debug)] - enum Field { - TypeName, - VendorName, - } - - impl<'de> Deserialize<'de> for Field { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct FieldVisitor; +#[derive(Debug)] +enum MetadataField { + TypeName, + VendorName, +} - impl<'de> Visitor<'de> for FieldVisitor { - type Value = Field; +struct MetadataFieldVisitor; - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("`type_name` or `vendor_name`") - } +impl<'de> Visitor<'de> for MetadataFieldVisitor { + type Value = MetadataField; - fn visit_str(self, value: &str) -> Result - where - E: serde_core::de::Error, - { - match value { - "type_name" => Ok(Field::TypeName), - "vendor_name" => Ok(Field::VendorName), - _ => Err(serde_core::de::Error::unknown_field( - value, - &["type_name", "vendor_name"], - )), - } - } - } + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("`type_name` or `vendor_name`") + } - deserializer.deserialize_identifier(FieldVisitor) - } + fn visit_str(self, value: &str) -> Result + where + E: serde_core::de::Error, + { + match value { + "type_name" => Ok(MetadataField::TypeName), + "vendor_name" => Ok(MetadataField::VendorName), + _ => Err(serde_core::de::Error::unknown_field( + value, + &["type_name", "vendor_name"], + )), } + } +} - struct OpaqueMetadataVisitor; +impl<'de> Deserialize<'de> for MetadataField { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_identifier(MetadataFieldVisitor) + } +} - impl<'de> Visitor<'de> for OpaqueMetadataVisitor { - type Value = OpaqueMetadata; +struct OpaqueMetadataVisitor; - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("struct OpaqueMetadata") - } +impl<'de> Visitor<'de> for OpaqueMetadataVisitor { + type Value = OpaqueMetadata; - fn visit_seq(self, mut seq: V) -> Result - where - V: serde_core::de::SeqAccess<'de>, - { - let type_name = seq - .next_element()? - .ok_or_else(|| serde_core::de::Error::invalid_length(0, &self))?; - let vendor_name = seq - .next_element()? - .ok_or_else(|| serde_core::de::Error::invalid_length(1, &self))?; - Ok(OpaqueMetadata { - type_name, - vendor_name, - }) - } + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("struct OpaqueMetadata") + } + + fn visit_seq(self, mut seq: V) -> Result + where + V: serde_core::de::SeqAccess<'de>, + { + let type_name = seq + .next_element()? + .ok_or_else(|| serde_core::de::Error::invalid_length(0, &self))?; + let vendor_name = seq + .next_element()? + .ok_or_else(|| serde_core::de::Error::invalid_length(1, &self))?; + Ok(OpaqueMetadata { + type_name, + vendor_name, + }) + } - fn visit_map(self, mut map: V) -> Result - where - V: MapAccess<'de>, - { - let mut type_name = None; - let mut vendor_name = None; - - while let Some(key) = map.next_key()? { - match key { - Field::TypeName => { - if type_name.is_some() { - return Err(serde_core::de::Error::duplicate_field("type_name")); - } - type_name = Some(map.next_value()?); - } - Field::VendorName => { - if vendor_name.is_some() { - return Err(serde_core::de::Error::duplicate_field("vendor_name")); - } - vendor_name = Some(map.next_value()?); - } + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + let mut type_name = None; + let mut vendor_name = None; + + while let Some(key) = map.next_key()? { + match key { + MetadataField::TypeName => { + if type_name.is_some() { + return Err(serde_core::de::Error::duplicate_field("type_name")); } + type_name = Some(map.next_value()?); + } + MetadataField::VendorName => { + if vendor_name.is_some() { + return Err(serde_core::de::Error::duplicate_field("vendor_name")); + } + vendor_name = Some(map.next_value()?); } - - let type_name = - type_name.ok_or_else(|| serde_core::de::Error::missing_field("type_name"))?; - let vendor_name = vendor_name - .ok_or_else(|| serde_core::de::Error::missing_field("vendor_name"))?; - - Ok(OpaqueMetadata { - type_name, - vendor_name, - }) } } - const FIELDS: &[&str] = &["type_name", "vendor_name"]; - deserializer.deserialize_struct("OpaqueMetadata", FIELDS, OpaqueMetadataVisitor) + let type_name = + type_name.ok_or_else(|| serde_core::de::Error::missing_field("type_name"))?; + let vendor_name = + vendor_name.ok_or_else(|| serde_core::de::Error::missing_field("vendor_name"))?; + + Ok(OpaqueMetadata { + type_name, + vendor_name, + }) + } +} + +impl<'de> Deserialize<'de> for OpaqueMetadata { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_struct( + "OpaqueMetadata", + &["type_name", "vendor_name"], + OpaqueMetadataVisitor, + ) } } diff --git a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs index cd1e459a4b6c..b5403dcf684f 100644 --- a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs +++ b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs @@ -169,124 +169,123 @@ impl Serialize for VariableShapeTensorMetadata { } } -impl<'de> Deserialize<'de> for VariableShapeTensorMetadata { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - #[derive(Debug)] - enum Field { - DimNames, - Permutations, - UniformShape, - } - - impl<'de> Deserialize<'de> for Field { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct FieldVisitor; +#[derive(Debug)] +enum MetadataField { + DimNames, + Permutations, + UniformShape, +} - impl<'de> Visitor<'de> for FieldVisitor { - type Value = Field; +struct MetadataFieldVisitor; - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("`dim_names`, `permutations`, or `uniform_shape`") - } +impl<'de> Visitor<'de> for MetadataFieldVisitor { + type Value = MetadataField; - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - match value { - "dim_names" => Ok(Field::DimNames), - "permutations" => Ok(Field::Permutations), - "uniform_shape" => Ok(Field::UniformShape), - _ => Err(de::Error::unknown_field( - value, - &["dim_names", "permutations", "uniform_shape"], - )), - } - } - } + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("`dim_names`, `permutations`, or `uniform_shape`") + } - deserializer.deserialize_identifier(FieldVisitor) - } + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match value { + "dim_names" => Ok(MetadataField::DimNames), + "permutations" => Ok(MetadataField::Permutations), + "uniform_shape" => Ok(MetadataField::UniformShape), + _ => Err(de::Error::unknown_field( + value, + &["dim_names", "permutations", "uniform_shape"], + )), } + } +} - struct VariableShapeTensorMetadataVisitor; +impl<'de> Deserialize<'de> for MetadataField { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_identifier(MetadataFieldVisitor) + } +} - impl<'de> Visitor<'de> for VariableShapeTensorMetadataVisitor { - type Value = VariableShapeTensorMetadata; +struct VariableShapeTensorMetadataVisitor; - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("struct VariableShapeTensorMetadata") - } +impl<'de> Visitor<'de> for VariableShapeTensorMetadataVisitor { + type Value = VariableShapeTensorMetadata; - fn visit_seq(self, mut seq: V) -> Result - where - V: de::SeqAccess<'de>, - { - let dim_names = seq - .next_element()? - .ok_or_else(|| de::Error::invalid_length(0, &self))?; - let permutations = seq - .next_element()? - .ok_or_else(|| de::Error::invalid_length(1, &self))?; - let uniform_shape = seq - .next_element()? - .ok_or_else(|| de::Error::invalid_length(2, &self))?; - Ok(VariableShapeTensorMetadata { - dim_names, - permutations, - uniform_shape, - }) - } + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct VariableShapeTensorMetadata") + } - fn visit_map(self, mut map: V) -> Result - where - V: MapAccess<'de>, - { - let mut dim_names = None; - let mut permutations = None; - let mut uniform_shape = None; - - while let Some(key) = map.next_key()? { - match key { - Field::DimNames => { - if dim_names.is_some() { - return Err(de::Error::duplicate_field("dim_names")); - } - dim_names = Some(map.next_value()?); - } - Field::Permutations => { - if permutations.is_some() { - return Err(de::Error::duplicate_field("permutations")); - } - permutations = Some(map.next_value()?); - } - Field::UniformShape => { - if uniform_shape.is_some() { - return Err(de::Error::duplicate_field("uniform_shape")); - } - uniform_shape = Some(map.next_value()?); - } + fn visit_seq(self, mut seq: V) -> Result + where + V: de::SeqAccess<'de>, + { + let dim_names = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + let permutations = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + let uniform_shape = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(2, &self))?; + Ok(VariableShapeTensorMetadata { + dim_names, + permutations, + uniform_shape, + }) + } + + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + let mut dim_names = None; + let mut permutations = None; + let mut uniform_shape = None; + + while let Some(key) = map.next_key()? { + match key { + MetadataField::DimNames => { + if dim_names.is_some() { + return Err(de::Error::duplicate_field("dim_names")); } + dim_names = Some(map.next_value()?); + } + MetadataField::Permutations => { + if permutations.is_some() { + return Err(de::Error::duplicate_field("permutations")); + } + permutations = Some(map.next_value()?); + } + MetadataField::UniformShape => { + if uniform_shape.is_some() { + return Err(de::Error::duplicate_field("uniform_shape")); + } + uniform_shape = Some(map.next_value()?); } - - Ok(VariableShapeTensorMetadata { - dim_names, - permutations, - uniform_shape, - }) } } - const FIELDS: &[&str] = &["dim_names", "permutations", "uniform_shape"]; + Ok(VariableShapeTensorMetadata { + dim_names, + permutations, + uniform_shape, + }) + } +} + +impl<'de> Deserialize<'de> for VariableShapeTensorMetadata { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { deserializer.deserialize_struct( "VariableShapeTensorMetadata", - FIELDS, + &["dim_names", "permutations", "uniform_shape"], VariableShapeTensorMetadataVisitor, ) }