diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index e999f505bca1..49b34c6137f7 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -2,7 +2,7 @@ We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. -Closes #NNN. +- Closes #NNN. # Rationale for this change @@ -13,6 +13,14 @@ Explaining clearly why changes are proposed helps reviewers understand your chan There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. +# Are these changes tested? + +We typically require tests for all PRs in order to: +1. Prevent the code from being accidentally broken by subsequent changes +2. Serve as another way to document the expected behavior of the code + +If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? + # Are there any user-facing changes? If there are user-facing changes then we may require documentation to be updated before approving the PR. diff --git a/parquet-variant/src/builder.rs b/parquet-variant/src/builder.rs index c595d72e0afc..1c6ebe23d24f 100644 --- a/parquet-variant/src/builder.rs +++ b/parquet-variant/src/builder.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. use crate::decoder::{VariantBasicType, VariantPrimitiveType}; -use crate::{ShortString, Variant}; -use std::collections::HashMap; +use crate::{ShortString, Variant, VariantDecimal16, VariantDecimal4, VariantDecimal8}; +use std::collections::BTreeMap; const BASIC_TYPE_BITS: u8 = 2; const UNIX_EPOCH_DATE: chrono::NaiveDate = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); @@ -166,7 +166,7 @@ fn make_room_for_header(buffer: &mut Vec, start_pos: usize, header_size: usi /// pub struct VariantBuilder { buffer: Vec, - dict: HashMap, + dict: BTreeMap, dict_keys: Vec, } @@ -174,7 +174,7 @@ impl VariantBuilder { pub fn new() -> Self { Self { buffer: Vec::new(), - dict: HashMap::new(), + dict: BTreeMap::new(), dict_keys: Vec::new(), } } @@ -296,7 +296,7 @@ impl VariantBuilder { /// Add key to dictionary, return its ID fn add_key(&mut self, key: &str) -> u32 { - use std::collections::hash_map::Entry; + use std::collections::btree_map::Entry; match self.dict.entry(key.to_string()) { Entry::Occupied(entry) => *entry.get(), Entry::Vacant(entry) => { @@ -384,9 +384,15 @@ impl VariantBuilder { Variant::Date(v) => self.append_date(v), Variant::TimestampMicros(v) => self.append_timestamp_micros(v), Variant::TimestampNtzMicros(v) => self.append_timestamp_ntz_micros(v), - Variant::Decimal4 { integer, scale } => self.append_decimal4(integer, scale), - Variant::Decimal8 { integer, scale } => self.append_decimal8(integer, scale), - Variant::Decimal16 { integer, scale } => self.append_decimal16(integer, scale), + Variant::Decimal4(VariantDecimal4 { integer, scale }) => { + self.append_decimal4(integer, scale) + } + Variant::Decimal8(VariantDecimal8 { integer, scale }) => { + self.append_decimal8(integer, scale) + } + Variant::Decimal16(VariantDecimal16 { integer, scale }) => { + self.append_decimal16(integer, scale) + } Variant::Float(v) => self.append_float(v), Variant::Double(v) => self.append_double(v), Variant::Binary(v) => self.append_binary(v), @@ -482,7 +488,7 @@ impl<'a> ListBuilder<'a> { pub struct ObjectBuilder<'a> { parent: &'a mut VariantBuilder, start_pos: usize, - fields: Vec<(u32, usize)>, // (field_id, offset) + fields: BTreeMap, // (field_id, offset) } impl<'a> ObjectBuilder<'a> { @@ -491,7 +497,7 @@ impl<'a> ObjectBuilder<'a> { Self { parent, start_pos, - fields: Vec::new(), + fields: BTreeMap::new(), } } @@ -500,25 +506,27 @@ impl<'a> ObjectBuilder<'a> { let id = self.parent.add_key(key); let field_start = self.parent.offset() - self.start_pos; self.parent.append_value(value); - self.fields.push((id, field_start)); + let res = self.fields.insert(id, field_start); + debug_assert!(res.is_none()); } /// Finalize object with sorted fields - pub fn finish(mut self) { - // Sort fields by key name - self.fields.sort_by(|a, b| { - let key_a = &self.parent.dict_keys[a.0 as usize]; - let key_b = &self.parent.dict_keys[b.0 as usize]; - key_a.cmp(key_b) - }); - + pub fn finish(self) { let data_size = self.parent.offset() - self.start_pos; let num_fields = self.fields.len(); let is_large = num_fields > u8::MAX as usize; let size_bytes = if is_large { 4 } else { 1 }; - let max_id = self.fields.iter().map(|&(id, _)| id).max().unwrap_or(0); - let id_size = int_size(max_id as usize); + let field_ids_by_sorted_field_name = self + .parent + .dict + .iter() + .filter_map(|(_, id)| self.fields.contains_key(id).then_some(*id)) + .collect::>(); + + let max_id = self.fields.keys().last().copied().unwrap_or(0) as usize; + + let id_size = int_size(max_id); let offset_size = int_size(data_size); let header_size = 1 @@ -542,17 +550,18 @@ impl<'a> ObjectBuilder<'a> { } // Write field IDs (sorted order) - for &(id, _) in &self.fields { + for id in &field_ids_by_sorted_field_name { write_offset( &mut self.parent.buffer[pos..pos + id_size as usize], - id as usize, + *id as usize, id_size, ); pos += id_size as usize; } // Write field offsets - for &(_, offset) in &self.fields { + for id in &field_ids_by_sorted_field_name { + let &offset = self.fields.get(id).unwrap(); write_offset( &mut self.parent.buffer[pos..pos + offset_size as usize], offset, @@ -749,6 +758,77 @@ mod tests { assert_eq!(field_ids, vec![1, 2, 0]); } + #[test] + fn test_object_and_metadata_ordering() { + let mut builder = VariantBuilder::new(); + + let mut obj = builder.new_object(); + + obj.append_value("zebra", "stripes"); // ID = 0 + obj.append_value("apple", "red"); // ID = 1 + + { + // fields_map is ordered by insertion order (field id) + let fields_map = obj.fields.keys().copied().collect::>(); + assert_eq!(fields_map, vec![0, 1]); + + // dict is ordered by field names + // NOTE: when we support nested objects, we'll want to perform a filter by fields_map field ids + let dict_metadata = obj + .parent + .dict + .iter() + .map(|(f, i)| (f.as_str(), *i)) + .collect::>(); + + assert_eq!(dict_metadata, vec![("apple", 1), ("zebra", 0)]); + + // dict_keys is ordered by insertion order (field id) + let dict_keys = obj + .parent + .dict_keys + .iter() + .map(|k| k.as_str()) + .collect::>(); + assert_eq!(dict_keys, vec!["zebra", "apple"]); + } + + obj.append_value("banana", "yellow"); // ID = 2 + + { + // fields_map is ordered by insertion order (field id) + let fields_map = obj.fields.keys().copied().collect::>(); + assert_eq!(fields_map, vec![0, 1, 2]); + + // dict is ordered by field names + // NOTE: when we support nested objects, we'll want to perform a filter by fields_map field ids + let dict_metadata = obj + .parent + .dict + .iter() + .map(|(f, i)| (f.as_str(), *i)) + .collect::>(); + + assert_eq!( + dict_metadata, + vec![("apple", 1), ("banana", 2), ("zebra", 0)] + ); + + // dict_keys is ordered by insertion order (field id) + let dict_keys = obj + .parent + .dict_keys + .iter() + .map(|k| k.as_str()) + .collect::>(); + assert_eq!(dict_keys, vec!["zebra", "apple", "banana"]); + } + + obj.finish(); + + builder.finish(); + } + #[test] fn test_append_object() { let (object_metadata, object_value) = { diff --git a/parquet-variant/src/to_json.rs b/parquet-variant/src/to_json.rs index 0cdcb8b49e63..80759b80a5c8 100644 --- a/parquet-variant/src/to_json.rs +++ b/parquet-variant/src/to_json.rs @@ -23,6 +23,7 @@ use serde_json::Value; use std::io::Write; use crate::variant::{Variant, VariantList, VariantObject}; +use crate::{VariantDecimal16, VariantDecimal4, VariantDecimal8}; // Format string constants to avoid duplication and reduce errors const DATE_FORMAT: &str = "%Y-%m-%d"; @@ -106,7 +107,7 @@ pub fn variant_to_json(json_buffer: &mut impl Write, variant: &Variant) -> Resul Variant::Double(f) => { write!(json_buffer, "{}", f)?; } - Variant::Decimal4 { integer, scale } => { + Variant::Decimal4(VariantDecimal4 { integer, scale }) => { // Convert decimal to string representation using integer arithmetic if *scale == 0 { write!(json_buffer, "{}", integer)?; @@ -123,7 +124,7 @@ pub fn variant_to_json(json_buffer: &mut impl Write, variant: &Variant) -> Resul } } } - Variant::Decimal8 { integer, scale } => { + Variant::Decimal8(VariantDecimal8 { integer, scale }) => { // Convert decimal to string representation using integer arithmetic if *scale == 0 { write!(json_buffer, "{}", integer)?; @@ -140,7 +141,7 @@ pub fn variant_to_json(json_buffer: &mut impl Write, variant: &Variant) -> Resul } } } - Variant::Decimal16 { integer, scale } => { + Variant::Decimal16(VariantDecimal16 { integer, scale }) => { // Convert decimal to string representation using integer arithmetic if *scale == 0 { write!(json_buffer, "{}", integer)?; @@ -364,7 +365,7 @@ pub fn variant_to_json_value(variant: &Variant) -> Result { Variant::Double(f) => serde_json::Number::from_f64(*f) .map(Value::Number) .ok_or_else(|| ArrowError::InvalidArgumentError("Invalid double value".to_string())), - Variant::Decimal4 { integer, scale } => { + Variant::Decimal4(VariantDecimal4 { integer, scale }) => { // Use integer arithmetic to avoid f64 precision loss if *scale == 0 { Ok(Value::Number((*integer).into())) @@ -390,7 +391,7 @@ pub fn variant_to_json_value(variant: &Variant) -> Result { }) } } - Variant::Decimal8 { integer, scale } => { + Variant::Decimal8(VariantDecimal8 { integer, scale }) => { // Use integer arithmetic to avoid f64 precision loss if *scale == 0 { Ok(Value::Number((*integer).into())) @@ -416,7 +417,7 @@ pub fn variant_to_json_value(variant: &Variant) -> Result { }) } } - Variant::Decimal16 { integer, scale } => { + Variant::Decimal16(VariantDecimal16 { integer, scale }) => { // Use integer arithmetic to avoid f64 precision loss if *scale == 0 { Ok(Value::Number((*integer as i64).into())) // Convert to i64 for JSON compatibility @@ -482,18 +483,12 @@ mod tests { #[test] fn test_decimal_edge_cases() -> Result<(), ArrowError> { // Test negative decimal - let negative_variant = Variant::Decimal4 { - integer: -12345, - scale: 3, - }; + let negative_variant = Variant::from(VariantDecimal4::try_new(-12345, 3)?); let negative_json = variant_to_json_string(&negative_variant)?; assert_eq!(negative_json, "-12.345"); // Test large scale decimal - let large_scale_variant = Variant::Decimal8 { - integer: 123456789, - scale: 6, - }; + let large_scale_variant = Variant::from(VariantDecimal8::try_new(123456789, 6)?); let large_scale_json = variant_to_json_string(&large_scale_variant)?; assert_eq!(large_scale_json, "123.456789"); @@ -502,10 +497,7 @@ mod tests { #[test] fn test_decimal16_to_json() -> Result<(), ArrowError> { - let variant = Variant::Decimal16 { - integer: 123456789012345, - scale: 4, - }; + let variant = Variant::from(VariantDecimal16::try_new(123456789012345, 4)?); let json = variant_to_json_string(&variant)?; assert_eq!(json, "12345678901.2345"); @@ -513,10 +505,7 @@ mod tests { assert!(matches!(json_value, Value::Number(_))); // Test very large number - let large_variant = Variant::Decimal16 { - integer: 999999999999999999, - scale: 2, - }; + let large_variant = Variant::from(VariantDecimal16::try_new(999999999999999999, 2)?); let large_json = variant_to_json_string(&large_variant)?; // Due to f64 precision limits, very large numbers may lose precision assert!( @@ -839,10 +828,7 @@ mod tests { // Decimals JsonTest { - variant: Variant::Decimal4 { - integer: 12345, - scale: 2, - }, + variant: Variant::from(VariantDecimal4::try_new(12345, 2).unwrap()), expected_json: "123.45", expected_value: serde_json::Number::from_f64(123.45) .map(Value::Number) @@ -851,10 +837,7 @@ mod tests { .run(); JsonTest { - variant: Variant::Decimal4 { - integer: 42, - scale: 0, - }, + variant: Variant::from(VariantDecimal4::try_new(42, 0).unwrap()), expected_json: "42", expected_value: serde_json::Number::from_f64(42.0) .map(Value::Number) @@ -863,10 +846,7 @@ mod tests { .run(); JsonTest { - variant: Variant::Decimal8 { - integer: 1234567890, - scale: 3, - }, + variant: Variant::from(VariantDecimal8::try_new(1234567890, 3).unwrap()), expected_json: "1234567.89", expected_value: serde_json::Number::from_f64(1234567.89) .map(Value::Number) @@ -875,10 +855,7 @@ mod tests { .run(); JsonTest { - variant: Variant::Decimal16 { - integer: 123456789012345, - scale: 4, - }, + variant: Variant::from(VariantDecimal16::try_new(123456789012345, 4).unwrap()), expected_json: "12345678901.2345", expected_value: serde_json::Number::from_f64(12345678901.2345) .map(Value::Number) @@ -1003,9 +980,7 @@ mod tests { // Parse the JSON to verify structure - handle JSON parsing errors manually let parsed: Value = serde_json::from_str(&json) .map_err(|e| ArrowError::ParseError(format!("JSON parse error: {}", e)))?; - let Value::Object(obj) = parsed else { - panic!("Expected JSON object"); - }; + let obj = parsed.as_object().expect("expected JSON object"); assert_eq!(obj.get("name"), Some(&Value::String("Alice".to_string()))); assert_eq!(obj.get("age"), Some(&Value::Number(30.into()))); assert_eq!(obj.get("active"), Some(&Value::Bool(true))); @@ -1094,9 +1069,7 @@ mod tests { assert_eq!(json, "[1,2,3,4,5]"); let json_value = variant_to_json_value(&variant)?; - let Value::Array(arr) = json_value else { - panic!("Expected JSON array"); - }; + let arr = json_value.as_array().expect("expected JSON array"); assert_eq!(arr.len(), 5); assert_eq!(arr[0], Value::Number(1.into())); assert_eq!(arr[4], Value::Number(5.into())); @@ -1148,9 +1121,7 @@ mod tests { let parsed: Value = serde_json::from_str(&json) .map_err(|e| ArrowError::ParseError(format!("JSON parse error: {}", e)))?; - let Value::Array(arr) = parsed else { - panic!("Expected JSON array"); - }; + let arr = parsed.as_array().expect("expected JSON array"); assert_eq!(arr.len(), 5); assert_eq!(arr[0], Value::String("hello".to_string())); assert_eq!(arr[1], Value::Number(42.into())); @@ -1183,9 +1154,7 @@ mod tests { // Parse and verify all fields are present let parsed: Value = serde_json::from_str(&json) .map_err(|e| ArrowError::ParseError(format!("JSON parse error: {}", e)))?; - let Value::Object(obj) = parsed else { - panic!("Expected JSON object"); - }; + let obj = parsed.as_object().expect("expected JSON object"); assert_eq!(obj.len(), 3); assert_eq!(obj.get("alpha"), Some(&Value::String("first".to_string()))); assert_eq!(obj.get("beta"), Some(&Value::String("second".to_string()))); @@ -1218,9 +1187,7 @@ mod tests { let parsed: Value = serde_json::from_str(&json) .map_err(|e| ArrowError::ParseError(format!("JSON parse error: {}", e)))?; - let Value::Array(arr) = parsed else { - panic!("Expected JSON array"); - }; + let arr = parsed.as_array().expect("expected JSON array"); assert_eq!(arr.len(), 7); assert_eq!(arr[0], Value::String("string_value".to_string())); assert_eq!(arr[1], Value::Number(42.into())); @@ -1256,9 +1223,7 @@ mod tests { let parsed: Value = serde_json::from_str(&json) .map_err(|e| ArrowError::ParseError(format!("JSON parse error: {}", e)))?; - let Value::Object(obj) = parsed else { - panic!("Expected JSON object"); - }; + let obj = parsed.as_object().expect("expected JSON object"); assert_eq!(obj.len(), 6); assert_eq!( obj.get("string_field"), @@ -1277,10 +1242,10 @@ mod tests { fn test_high_precision_decimal_no_loss() -> Result<(), ArrowError> { // Test case that would lose precision with f64 conversion // This is a 63-bit precision decimal8 value that f64 cannot represent exactly - let high_precision_decimal8 = Variant::Decimal8 { - integer: 9007199254740993, // 2^53 + 1, exceeds f64 precision - scale: 6, - }; + let high_precision_decimal8 = Variant::from(VariantDecimal8::try_new( + 9007199254740993, // 2^53 + 1, exceeds f64 precision + 6, + )?); let json_string = variant_to_json_string(&high_precision_decimal8)?; let json_value = variant_to_json_value(&high_precision_decimal8)?; @@ -1294,10 +1259,10 @@ mod tests { assert_eq!(parsed, json_value); // Test another case with trailing zeros that should be trimmed - let decimal_with_zeros = Variant::Decimal8 { - integer: 1234567890000, // Should result in 1234567.89 (trailing zeros trimmed) - scale: 6, - }; + let decimal_with_zeros = Variant::from(VariantDecimal8::try_new( + 1234567890000, // Should result in 1234567.89 (trailing zeros trimmed) + 6, + )?); let json_string_zeros = variant_to_json_string(&decimal_with_zeros)?; assert_eq!(json_string_zeros, "1234567.89"); diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs index 734f9fa23cfd..923faf8431c5 100644 --- a/parquet-variant/src/variant.rs +++ b/parquet-variant/src/variant.rs @@ -40,8 +40,100 @@ const MAX_SHORT_STRING_BYTES: usize = 0x3F; #[derive(Debug, Clone, Copy, PartialEq)] pub struct ShortString<'a>(pub(crate) &'a str); +/// Represents a 4-byte decimal value in the Variant format. +/// +/// This struct stores a decimal number using a 32-bit signed integer for the coefficient +/// and an 8-bit unsigned integer for the scale (number of decimal places). Its precision is limited to 9 digits. +/// +/// For valid precision and scale values, see the Variant specification: +/// +/// +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct VariantDecimal4 { + pub(crate) integer: i32, + pub(crate) scale: u8, +} + +impl VariantDecimal4 { + pub fn try_new(integer: i32, scale: u8) -> Result { + const PRECISION_MAX: u32 = 9; + + // Validate that scale doesn't exceed precision + if scale as u32 > PRECISION_MAX { + return Err(ArrowError::InvalidArgumentError(format!( + "Scale {} cannot be greater than precision 9 for 4-byte decimal", + scale + ))); + } + + Ok(VariantDecimal4 { integer, scale }) + } +} + +/// Represents an 8-byte decimal value in the Variant format. +/// +/// This struct stores a decimal number using a 64-bit signed integer for the coefficient +/// and an 8-bit unsigned integer for the scale (number of decimal places). Its precision is between 10 and 18 digits. +/// +/// For valid precision and scale values, see the Variant specification: +/// +/// +/// +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct VariantDecimal8 { + pub(crate) integer: i64, + pub(crate) scale: u8, +} + +impl VariantDecimal8 { + pub fn try_new(integer: i64, scale: u8) -> Result { + const PRECISION_MAX: u32 = 18; + + // Validate that scale doesn't exceed precision + if scale as u32 > PRECISION_MAX { + return Err(ArrowError::InvalidArgumentError(format!( + "Scale {} cannot be greater than precision 18 for 8-byte decimal", + scale + ))); + } + + Ok(VariantDecimal8 { integer, scale }) + } +} + +/// Represents an 16-byte decimal value in the Variant format. +/// +/// This struct stores a decimal number using a 128-bit signed integer for the coefficient +/// and an 8-bit unsigned integer for the scale (number of decimal places). Its precision is between 19 and 38 digits. +/// +/// For valid precision and scale values, see the Variant specification: +/// +/// +/// +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct VariantDecimal16 { + pub(crate) integer: i128, + pub(crate) scale: u8, +} + +impl VariantDecimal16 { + pub fn try_new(integer: i128, scale: u8) -> Result { + const PRECISION_MAX: u32 = 38; + + // Validate that scale doesn't exceed precision + if scale as u32 > PRECISION_MAX { + return Err(ArrowError::InvalidArgumentError(format!( + "Scale {} cannot be greater than precision 38 for 16-byte decimal", + scale + ))); + } + + Ok(VariantDecimal16 { integer, scale }) + } +} + impl<'a> ShortString<'a> { - /// Attempts to interpret `value` as a variant short string value. + /// Attempts to interpret `value` as a variant short string value. /// /// # Validation /// @@ -194,11 +286,11 @@ pub enum Variant<'m, 'v> { /// Primitive (type_id=1): TIMESTAMP(isAdjustedToUTC=false, MICROS) TimestampNtzMicros(NaiveDateTime), /// Primitive (type_id=1): DECIMAL(precision, scale) 32-bits - Decimal4 { integer: i32, scale: u8 }, + Decimal4(VariantDecimal4), /// Primitive (type_id=1): DECIMAL(precision, scale) 64-bits - Decimal8 { integer: i64, scale: u8 }, + Decimal8(VariantDecimal8), /// Primitive (type_id=1): DECIMAL(precision, scale) 128-bits - Decimal16 { integer: i128, scale: u8 }, + Decimal16(VariantDecimal16), /// Primitive (type_id=1): FLOAT Float(f32), /// Primitive (type_id=1): DOUBLE @@ -269,15 +361,15 @@ impl<'m, 'v> Variant<'m, 'v> { VariantPrimitiveType::Int64 => Variant::Int64(decoder::decode_int64(value_data)?), VariantPrimitiveType::Decimal4 => { let (integer, scale) = decoder::decode_decimal4(value_data)?; - Variant::Decimal4 { integer, scale } + Variant::Decimal4(VariantDecimal4 { integer, scale }) } VariantPrimitiveType::Decimal8 => { let (integer, scale) = decoder::decode_decimal8(value_data)?; - Variant::Decimal8 { integer, scale } + Variant::Decimal8(VariantDecimal8 { integer, scale }) } VariantPrimitiveType::Decimal16 => { let (integer, scale) = decoder::decode_decimal16(value_data)?; - Variant::Decimal16 { integer, scale } + Variant::Decimal16(VariantDecimal16 { integer, scale }) } VariantPrimitiveType::Float => Variant::Float(decoder::decode_float(value_data)?), VariantPrimitiveType::Double => { @@ -640,18 +732,18 @@ impl<'m, 'v> Variant<'m, 'v> { /// # Examples /// /// ``` - /// use parquet_variant::Variant; + /// use parquet_variant::{Variant, VariantDecimal4, VariantDecimal8}; /// /// // you can extract decimal parts from smaller or equally-sized decimal variants - /// let v1 = Variant::from((1234_i32, 2)); + /// let v1 = Variant::from(VariantDecimal4::try_new(1234_i32, 2).unwrap()); /// assert_eq!(v1.as_decimal_int32(), Some((1234_i32, 2))); /// /// // and from larger decimal variants if they fit - /// let v2 = Variant::from((1234_i64, 2)); + /// let v2 = Variant::from(VariantDecimal8::try_new(1234_i64, 2).unwrap()); /// assert_eq!(v2.as_decimal_int32(), Some((1234_i32, 2))); /// /// // but not if the value would overflow i32 - /// let v3 = Variant::from((12345678901i64, 2)); + /// let v3 = Variant::from(VariantDecimal8::try_new(12345678901i64, 2).unwrap()); /// assert_eq!(v3.as_decimal_int32(), None); /// /// // or if the variant is not a decimal @@ -660,17 +752,17 @@ impl<'m, 'v> Variant<'m, 'v> { /// ``` pub fn as_decimal_int32(&self) -> Option<(i32, u8)> { match *self { - Variant::Decimal4 { integer, scale } => Some((integer, scale)), - Variant::Decimal8 { integer, scale } => { - if let Ok(converted_integer) = integer.try_into() { - Some((converted_integer, scale)) + Variant::Decimal4(decimal4) => Some((decimal4.integer, decimal4.scale)), + Variant::Decimal8(decimal8) => { + if let Ok(converted_integer) = decimal8.integer.try_into() { + Some((converted_integer, decimal8.scale)) } else { None } } - Variant::Decimal16 { integer, scale } => { - if let Ok(converted_integer) = integer.try_into() { - Some((converted_integer, scale)) + Variant::Decimal16(decimal16) => { + if let Ok(converted_integer) = decimal16.integer.try_into() { + Some((converted_integer, decimal16.scale)) } else { None } @@ -688,18 +780,18 @@ impl<'m, 'v> Variant<'m, 'v> { /// # Examples /// /// ``` - /// use parquet_variant::Variant; + /// use parquet_variant::{Variant, VariantDecimal8, VariantDecimal16}; /// /// // you can extract decimal parts from smaller or equally-sized decimal variants - /// let v1 = Variant::from((1234_i64, 2)); + /// let v1 = Variant::from(VariantDecimal8::try_new(1234_i64, 2).unwrap()); /// assert_eq!(v1.as_decimal_int64(), Some((1234_i64, 2))); /// /// // and from larger decimal variants if they fit - /// let v2 = Variant::from((1234_i128, 2)); + /// let v2 = Variant::from(VariantDecimal16::try_new(1234_i128, 2).unwrap()); /// assert_eq!(v2.as_decimal_int64(), Some((1234_i64, 2))); /// /// // but not if the value would overflow i64 - /// let v3 = Variant::from((2e19 as i128, 2)); + /// let v3 = Variant::from(VariantDecimal16::try_new(2e19 as i128, 2).unwrap()); /// assert_eq!(v3.as_decimal_int64(), None); /// /// // or if the variant is not a decimal @@ -708,11 +800,11 @@ impl<'m, 'v> Variant<'m, 'v> { /// ``` pub fn as_decimal_int64(&self) -> Option<(i64, u8)> { match *self { - Variant::Decimal4 { integer, scale } => Some((integer.into(), scale)), - Variant::Decimal8 { integer, scale } => Some((integer, scale)), - Variant::Decimal16 { integer, scale } => { - if let Ok(converted_integer) = integer.try_into() { - Some((converted_integer, scale)) + Variant::Decimal4(decimal) => Some((decimal.integer.into(), decimal.scale)), + Variant::Decimal8(decimal) => Some((decimal.integer, decimal.scale)), + Variant::Decimal16(decimal) => { + if let Ok(converted_integer) = decimal.integer.try_into() { + Some((converted_integer, decimal.scale)) } else { None } @@ -730,10 +822,10 @@ impl<'m, 'v> Variant<'m, 'v> { /// # Examples /// /// ``` - /// use parquet_variant::Variant; + /// use parquet_variant::{Variant, VariantDecimal16}; /// /// // you can extract decimal parts from smaller or equally-sized decimal variants - /// let v1 = Variant::from((1234_i128, 2)); + /// let v1 = Variant::from(VariantDecimal16::try_new(1234_i128, 2).unwrap()); /// assert_eq!(v1.as_decimal_int128(), Some((1234_i128, 2))); /// /// // but not if the variant is not a decimal @@ -742,9 +834,9 @@ impl<'m, 'v> Variant<'m, 'v> { /// ``` pub fn as_decimal_int128(&self) -> Option<(i128, u8)> { match *self { - Variant::Decimal4 { integer, scale } => Some((integer.into(), scale)), - Variant::Decimal8 { integer, scale } => Some((integer.into(), scale)), - Variant::Decimal16 { integer, scale } => Some((integer, scale)), + Variant::Decimal4(decimal) => Some((decimal.integer.into(), decimal.scale)), + Variant::Decimal8(decimal) => Some((decimal.integer.into(), decimal.scale)), + Variant::Decimal16(decimal) => Some((decimal.integer, decimal.scale)), _ => None, } } @@ -809,6 +901,70 @@ impl<'m, 'v> Variant<'m, 'v> { } } + /// Converts this variant to an `Object` if it is an [`VariantObject`]. + /// + /// Returns `Some(&VariantObject)` for object variants, + /// `None` for non-object variants. + /// + /// # Examples + /// ``` + /// # use parquet_variant::{Variant, VariantBuilder, VariantObject}; + /// # let (metadata, value) = { + /// # let mut builder = VariantBuilder::new(); + /// # let mut obj = builder.new_object(); + /// # obj.append_value("name", "John"); + /// # obj.finish(); + /// # builder.finish() + /// # }; + /// // object that is {"name": "John"} + /// let variant = Variant::try_new(&metadata, &value).unwrap(); + /// // use the `as_object` method to access the object + /// let obj = variant.as_object().expect("variant should be an object"); + /// assert_eq!(obj.field_by_name("name").unwrap(), Some(Variant::from("John"))); + /// ``` + pub fn as_object(&'m self) -> Option<&'m VariantObject<'m, 'v>> { + if let Variant::Object(obj) = self { + Some(obj) + } else { + None + } + } + + /// Converts this variant to a `List` if it is a [`VariantList`]. + /// + /// Returns `Some(&VariantList)` for list variants, + /// `None` for non-list variants. + /// + /// # Examples + /// ``` + /// # use parquet_variant::{Variant, VariantBuilder, VariantList}; + /// # let (metadata, value) = { + /// # let mut builder = VariantBuilder::new(); + /// # let mut list = builder.new_list(); + /// # list.append_value("John"); + /// # list.append_value("Doe"); + /// # list.finish(); + /// # builder.finish() + /// # }; + /// // list that is ["John", "Doe"] + /// let variant = Variant::try_new(&metadata, &value).unwrap(); + /// // use the `as_list` method to access the list + /// let list = variant.as_list().expect("variant should be a list"); + /// assert_eq!(list.len(), 2); + /// assert_eq!(list.get(0).unwrap(), Variant::from("John")); + /// assert_eq!(list.get(1).unwrap(), Variant::from("Doe")); + /// ``` + pub fn as_list(&'m self) -> Option<&'m VariantList<'m, 'v>> { + if let Variant::List(list) = self { + Some(list) + } else { + None + } + } + + /// Return the metadata associated with this variant, if any. + /// + /// Returns `Some(&VariantMetadata)` for object and list variants, pub fn metadata(&self) -> Option<&'m VariantMetadata> { match self { Variant::Object(VariantObject { metadata, .. }) @@ -848,12 +1004,9 @@ impl From for Variant<'_, '_> { } } -impl From<(i32, u8)> for Variant<'_, '_> { - fn from(value: (i32, u8)) -> Self { - Variant::Decimal4 { - integer: value.0, - scale: value.1, - } +impl From for Variant<'_, '_> { + fn from(value: VariantDecimal4) -> Self { + Variant::Decimal4(value) } } impl From for Variant<'_, '_> { @@ -865,21 +1018,15 @@ impl From for Variant<'_, '_> { } } -impl From<(i64, u8)> for Variant<'_, '_> { - fn from(value: (i64, u8)) -> Self { - Variant::Decimal8 { - integer: value.0, - scale: value.1, - } +impl From for Variant<'_, '_> { + fn from(value: VariantDecimal8) -> Self { + Variant::Decimal8(value) } } -impl From<(i128, u8)> for Variant<'_, '_> { - fn from(value: (i128, u8)) -> Self { - Variant::Decimal16 { - integer: value.0, - scale: value.1, - } +impl From for Variant<'_, '_> { + fn from(value: VariantDecimal16) -> Self { + Variant::Decimal16(value) } } @@ -928,6 +1075,36 @@ impl<'v> From<&'v str> for Variant<'_, 'v> { } } +impl TryFrom<(i32, u8)> for Variant<'_, '_> { + type Error = ArrowError; + + fn try_from(value: (i32, u8)) -> Result { + Ok(Variant::Decimal4(VariantDecimal4::try_new( + value.0, value.1, + )?)) + } +} + +impl TryFrom<(i64, u8)> for Variant<'_, '_> { + type Error = ArrowError; + + fn try_from(value: (i64, u8)) -> Result { + Ok(Variant::Decimal8(VariantDecimal8::try_new( + value.0, value.1, + )?)) + } +} + +impl TryFrom<(i128, u8)> for Variant<'_, '_> { + type Error = ArrowError; + + fn try_from(value: (i128, u8)) -> Result { + Ok(Variant::Decimal16(VariantDecimal16::try_new( + value.0, value.1, + )?)) + } +} + #[cfg(test)] mod tests { use super::*; @@ -941,4 +1118,28 @@ mod tests { let res = ShortString::try_new(&long_string); assert!(res.is_err()); } + + #[test] + fn test_variant_decimal_conversion() { + let decimal4 = VariantDecimal4::try_new(1234_i32, 2).unwrap(); + let variant = Variant::from(decimal4); + assert_eq!(variant.as_decimal_int32(), Some((1234_i32, 2))); + + let decimal8 = VariantDecimal8::try_new(12345678901_i64, 2).unwrap(); + let variant = Variant::from(decimal8); + assert_eq!(variant.as_decimal_int64(), Some((12345678901_i64, 2))); + + let decimal16 = VariantDecimal16::try_new(123456789012345678901234567890_i128, 2).unwrap(); + let variant = Variant::from(decimal16); + assert_eq!( + variant.as_decimal_int128(), + Some((123456789012345678901234567890_i128, 2)) + ); + } + + #[test] + fn test_invalid_variant_decimal_conversion() { + let decimal4 = VariantDecimal4::try_new(123456789_i32, 20); + assert!(decimal4.is_err(), "i32 overflow should fail"); + } } diff --git a/parquet-variant/tests/variant_interop.rs b/parquet-variant/tests/variant_interop.rs index bfa2ab267c27..be63357422e4 100644 --- a/parquet-variant/tests/variant_interop.rs +++ b/parquet-variant/tests/variant_interop.rs @@ -24,7 +24,9 @@ use std::fs; use std::path::{Path, PathBuf}; use chrono::NaiveDate; -use parquet_variant::{ShortString, Variant, VariantBuilder}; +use parquet_variant::{ + ShortString, Variant, VariantBuilder, VariantDecimal16, VariantDecimal4, VariantDecimal8, +}; fn cases_dir() -> PathBuf { Path::new(env!("CARGO_MANIFEST_DIR")) @@ -63,9 +65,10 @@ fn get_primitive_cases() -> Vec<(&'static str, Variant<'static, 'static>)> { ("primitive_boolean_false", Variant::BooleanFalse), ("primitive_boolean_true", Variant::BooleanTrue), ("primitive_date", Variant::Date(NaiveDate::from_ymd_opt(2025, 4 , 16).unwrap())), - ("primitive_decimal4", Variant::Decimal4{integer: 1234, scale: 2}), - ("primitive_decimal8", Variant::Decimal8{integer: 1234567890, scale: 2}), - ("primitive_decimal16", Variant::Decimal16{integer: 1234567891234567890, scale: 2}), + ("primitive_decimal4", Variant::from(VariantDecimal4::try_new(1234i32, 2u8).unwrap())), + // ("primitive_decimal8", Variant::Decimal8{integer: 1234567890, scale: 2}), + ("primitive_decimal8", Variant::Decimal8(VariantDecimal8::try_new(1234567890,2).unwrap())), + ("primitive_decimal16", Variant::Decimal16(VariantDecimal16::try_new(1234567891234567890, 2).unwrap())), ("primitive_float", Variant::Float(1234567890.1234)), ("primitive_double", Variant::Double(1234567890.1234)), ("primitive_int8", Variant::Int8(42)), @@ -123,10 +126,7 @@ fn variant_object_primitive() { // spark wrote this as a decimal4 (not a double) ( "double_field", - Variant::Decimal4 { - integer: 123456789, - scale: 8, - }, + Variant::Decimal4(VariantDecimal4::try_new(123456789, 8).unwrap()), ), ("int_field", Variant::Int8(1)), ("null_field", Variant::Null), @@ -210,7 +210,10 @@ fn variant_object_builder() { // The double field is actually encoded as decimal4 with scale 8 // Value: 123456789, Scale: 8 -> 1.23456789 - obj.append_value("double_field", (123456789i32, 8u8)); + obj.append_value( + "double_field", + VariantDecimal4::try_new(123456789i32, 8u8).unwrap(), + ); obj.append_value("boolean_true_field", true); obj.append_value("boolean_false_field", false); obj.append_value("string_field", "Apache Parquet");