Skip to content

Commit 8e5d826

Browse files
authored
[Variant] Decimal unshredding support (#8540)
# Which issue does this PR close?

- Closes #8332

# Rationale for this change

Missing feature

# What changes are included in this PR?

Add decimal unshredding support, which _should_ have been straightforward except:

1. The variant decimal types are not generic and do not implement any common trait that lets us generalize the logic easily. I added a custom trait in the unshredding module as a workaround, but we should probably look at something similar to arrow's `DecimalType` trait for the `VariantDecimalXX` classes to implement.
2. The parquet reader seems to have a bug (feature?) that forces 32- and 64-bit decimal columns to Decimal128 unless the reader specifically requests a narrower type. This causes the variant decimal integration tests to fail because they receive `Variant::Decimal16` values when they expected `Variant::Decimal4` or `Variant::Decimal8` (the actual values are correct). Rather than directly tackle the bug in arrow-parquet itself (which has a large blast radius), I updated the `VariantArray` constructor to cast such columns back to the correct type as needed.

# Are these changes tested?

Yes. The variant decimal integration tests now pass where they used to fail.

# Are there any user-facing changes?

No.
1 parent 2d900a4 commit 8e5d826

File tree

3 files changed

+152
-33
lines changed

3 files changed

+152
-33
lines changed

parquet-variant-compute/src/unshred_variant.rs

Lines changed: 123 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,18 @@ use arrow::array::{
2525
};
2626
use arrow::buffer::NullBuffer;
2727
use arrow::datatypes::{
28-
ArrowPrimitiveType, DataType, Date32Type, Float32Type, Float64Type, Int8Type, Int16Type,
29-
Int32Type, Int64Type, Time64MicrosecondType, TimeUnit, TimestampMicrosecondType,
30-
TimestampNanosecondType,
28+
ArrowPrimitiveType, DataType, Date32Type, Decimal32Type, Decimal64Type, Decimal128Type,
29+
DecimalType, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type,
30+
Time64MicrosecondType, TimeUnit, TimestampMicrosecondType, TimestampNanosecondType,
3131
};
3232
use arrow::error::{ArrowError, Result};
3333
use arrow::temporal_conversions::time64us_to_time;
3434
use chrono::{DateTime, Utc};
3535
use indexmap::IndexMap;
36-
use parquet_variant::{ObjectFieldBuilder, Variant, VariantBuilderExt, VariantMetadata};
36+
use parquet_variant::{
37+
ObjectFieldBuilder, Variant, VariantBuilderExt, VariantDecimal4, VariantDecimal8,
38+
VariantDecimal16, VariantMetadata,
39+
};
3740
use uuid::Uuid;
3841

3942
/// Removes all (nested) typed_value columns from a VariantArray by converting them back to binary
@@ -92,6 +95,9 @@ enum UnshredVariantRowBuilder<'a> {
9295
PrimitiveInt64(UnshredPrimitiveRowBuilder<'a, PrimitiveArray<Int64Type>>),
9396
PrimitiveFloat32(UnshredPrimitiveRowBuilder<'a, PrimitiveArray<Float32Type>>),
9497
PrimitiveFloat64(UnshredPrimitiveRowBuilder<'a, PrimitiveArray<Float64Type>>),
98+
Decimal32(DecimalUnshredRowBuilder<'a, Decimal32Spec>),
99+
Decimal64(DecimalUnshredRowBuilder<'a, Decimal64Spec>),
100+
Decimal128(DecimalUnshredRowBuilder<'a, Decimal128Spec>),
95101
PrimitiveDate32(UnshredPrimitiveRowBuilder<'a, PrimitiveArray<Date32Type>>),
96102
PrimitiveTime64(UnshredPrimitiveRowBuilder<'a, PrimitiveArray<Time64MicrosecondType>>),
97103
TimestampMicrosecond(TimestampUnshredRowBuilder<'a, TimestampMicrosecondType>),
@@ -130,6 +136,9 @@ impl<'a> UnshredVariantRowBuilder<'a> {
130136
Self::PrimitiveInt64(b) => b.append_row(builder, metadata, index),
131137
Self::PrimitiveFloat32(b) => b.append_row(builder, metadata, index),
132138
Self::PrimitiveFloat64(b) => b.append_row(builder, metadata, index),
139+
Self::Decimal32(b) => b.append_row(builder, metadata, index),
140+
Self::Decimal64(b) => b.append_row(builder, metadata, index),
141+
Self::Decimal128(b) => b.append_row(builder, metadata, index),
133142
Self::PrimitiveDate32(b) => b.append_row(builder, metadata, index),
134143
Self::PrimitiveTime64(b) => b.append_row(builder, metadata, index),
135144
Self::TimestampMicrosecond(b) => b.append_row(builder, metadata, index),
@@ -176,6 +185,26 @@ impl<'a> UnshredVariantRowBuilder<'a> {
176185
DataType::Int64 => primitive_builder!(PrimitiveInt64, as_primitive),
177186
DataType::Float32 => primitive_builder!(PrimitiveFloat32, as_primitive),
178187
DataType::Float64 => primitive_builder!(PrimitiveFloat64, as_primitive),
188+
DataType::Decimal32(_, scale) => Self::Decimal32(DecimalUnshredRowBuilder::new(
189+
value,
190+
typed_value.as_primitive(),
191+
*scale,
192+
)),
193+
DataType::Decimal64(_, scale) => Self::Decimal64(DecimalUnshredRowBuilder::new(
194+
value,
195+
typed_value.as_primitive(),
196+
*scale,
197+
)),
198+
DataType::Decimal128(_, scale) => Self::Decimal128(DecimalUnshredRowBuilder::new(
199+
value,
200+
typed_value.as_primitive(),
201+
*scale,
202+
)),
203+
DataType::Decimal256(_, _) => {
204+
return Err(ArrowError::InvalidArgumentError(
205+
"Decimal256 is not a valid variant shredding type".to_string(),
206+
));
207+
}
179208
DataType::Date32 => primitive_builder!(PrimitiveDate32, as_primitive),
180209
DataType::Time64(TimeUnit::Microsecond) => {
181210
primitive_builder!(PrimitiveTime64, as_primitive)
@@ -475,6 +504,96 @@ impl<'a, T: TimestampType> TimestampUnshredRowBuilder<'a, T> {
475504
}
476505
}
477506

507+
/// Trait to unify decimal unshredding across Decimal32/64/128 types
508+
trait DecimalSpec {
509+
type Arrow: ArrowPrimitiveType + DecimalType;
510+
511+
fn into_variant(
512+
raw: <Self::Arrow as ArrowPrimitiveType>::Native,
513+
scale: i8,
514+
) -> Result<Variant<'static, 'static>>;
515+
}
516+
517+
/// Spec for Decimal32 -> VariantDecimal4
518+
struct Decimal32Spec;
519+
520+
impl DecimalSpec for Decimal32Spec {
521+
type Arrow = Decimal32Type;
522+
523+
fn into_variant(raw: i32, scale: i8) -> Result<Variant<'static, 'static>> {
524+
let scale =
525+
u8::try_from(scale).map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?;
526+
let value = VariantDecimal4::try_new(raw, scale)
527+
.map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?;
528+
Ok(value.into())
529+
}
530+
}
531+
532+
/// Spec for Decimal64 -> VariantDecimal8
533+
struct Decimal64Spec;
534+
535+
impl DecimalSpec for Decimal64Spec {
536+
type Arrow = Decimal64Type;
537+
538+
fn into_variant(raw: i64, scale: i8) -> Result<Variant<'static, 'static>> {
539+
let scale =
540+
u8::try_from(scale).map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?;
541+
let value = VariantDecimal8::try_new(raw, scale)
542+
.map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?;
543+
Ok(value.into())
544+
}
545+
}
546+
547+
/// Spec for Decimal128 -> VariantDecimal16
548+
struct Decimal128Spec;
549+
550+
impl DecimalSpec for Decimal128Spec {
551+
type Arrow = Decimal128Type;
552+
553+
fn into_variant(raw: i128, scale: i8) -> Result<Variant<'static, 'static>> {
554+
let scale =
555+
u8::try_from(scale).map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?;
556+
let value = VariantDecimal16::try_new(raw, scale)
557+
.map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?;
558+
Ok(value.into())
559+
}
560+
}
561+
562+
/// Generic builder for decimal unshredding that caches scale
563+
struct DecimalUnshredRowBuilder<'a, S: DecimalSpec> {
564+
value: Option<&'a BinaryViewArray>,
565+
typed_value: &'a PrimitiveArray<S::Arrow>,
566+
scale: i8,
567+
}
568+
569+
impl<'a, S: DecimalSpec> DecimalUnshredRowBuilder<'a, S> {
570+
fn new(
571+
value: Option<&'a BinaryViewArray>,
572+
typed_value: &'a PrimitiveArray<S::Arrow>,
573+
scale: i8,
574+
) -> Self {
575+
Self {
576+
value,
577+
typed_value,
578+
scale,
579+
}
580+
}
581+
582+
fn append_row(
583+
&mut self,
584+
builder: &mut impl VariantBuilderExt,
585+
metadata: &VariantMetadata,
586+
index: usize,
587+
) -> Result<()> {
588+
handle_unshredded_case!(self, builder, metadata, index, false);
589+
590+
let raw = self.typed_value.value(index);
591+
let variant = S::into_variant(raw, self.scale)?;
592+
builder.append_value(variant);
593+
Ok(())
594+
}
595+
}
596+
478597
/// Builder for unshredding struct/object types with nested fields
479598
struct StructUnshredVariantBuilder<'a> {
480599
value: Option<&'a arrow::array::BinaryViewArray>,

parquet-variant-compute/src/variant_array.rs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@ use arrow::datatypes::{
2626
TimestampMicrosecondType, TimestampNanosecondType,
2727
};
2828
use arrow_schema::extension::ExtensionType;
29-
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, TimeUnit};
29+
use arrow_schema::{
30+
ArrowError, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, DECIMAL128_MAX_PRECISION,
31+
DataType, Field, FieldRef, Fields, TimeUnit,
32+
};
3033
use chrono::DateTime;
3134
use parquet_variant::Uuid;
3235
use parquet_variant::Variant;
@@ -926,6 +929,11 @@ fn typed_value_to_variant<'a>(
926929
/// So cast them to get the right type.
927930
fn cast_to_binary_view_arrays(array: &dyn Array) -> Result<ArrayRef, ArrowError> {
928931
let new_type = canonicalize_and_verify_data_type(array.data_type())?;
932+
if let Cow::Borrowed(_) = new_type {
933+
if let Some(array) = array.as_struct_opt() {
934+
return Ok(Arc::new(array.clone())); // bypass the unnecessary cast
935+
}
936+
}
929937
cast(array, new_type.as_ref())
930938
}
931939

@@ -972,9 +980,20 @@ fn canonicalize_and_verify_data_type(
972980
UInt8 | UInt16 | UInt32 | UInt64 | Float16 => fail!(),
973981

974982
// Most decimal types are allowed, with restrictions on precision and scale
975-
Decimal32(p, s) if is_valid_variant_decimal(p, s, 9) => borrow!(),
976-
Decimal64(p, s) if is_valid_variant_decimal(p, s, 18) => borrow!(),
977-
Decimal128(p, s) if is_valid_variant_decimal(p, s, 38) => borrow!(),
983+
//
984+
// NOTE: arrow-parquet widens 32- and 64-bit decimals to 128-bit on read, but the variant spec
985+
// requires using the narrowest decimal type for a given precision. Fix those up first.
986+
Decimal64(p, s) | Decimal128(p, s)
987+
if is_valid_variant_decimal(p, s, DECIMAL32_MAX_PRECISION) =>
988+
{
989+
Cow::Owned(Decimal32(*p, *s))
990+
}
991+
Decimal128(p, s) if is_valid_variant_decimal(p, s, DECIMAL64_MAX_PRECISION) => {
992+
Cow::Owned(Decimal64(*p, *s))
993+
}
994+
Decimal32(p, s) if is_valid_variant_decimal(p, s, DECIMAL32_MAX_PRECISION) => borrow!(),
995+
Decimal64(p, s) if is_valid_variant_decimal(p, s, DECIMAL64_MAX_PRECISION) => borrow!(),
996+
Decimal128(p, s) if is_valid_variant_decimal(p, s, DECIMAL128_MAX_PRECISION) => borrow!(),
978997
Decimal32(..) | Decimal64(..) | Decimal128(..) | Decimal256(..) => fail!(),
979998

980999
// Only micro and nano timestamps are allowed

parquet/tests/variant_integration.rs

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -86,31 +86,12 @@ variant_test_case!(20);
8686
variant_test_case!(21);
8787
variant_test_case!(22);
8888
variant_test_case!(23);
89-
// https://github.com/apache/arrow-rs/issues/8332
90-
variant_test_case!(
91-
24,
92-
"Unshredding not yet supported for type: Decimal128(9, 4)"
93-
);
94-
variant_test_case!(
95-
25,
96-
"Unshredding not yet supported for type: Decimal128(9, 4)"
97-
);
98-
variant_test_case!(
99-
26,
100-
"Unshredding not yet supported for type: Decimal128(18, 9)"
101-
);
102-
variant_test_case!(
103-
27,
104-
"Unshredding not yet supported for type: Decimal128(18, 9)"
105-
);
106-
variant_test_case!(
107-
28,
108-
"Unshredding not yet supported for type: Decimal128(38, 9)"
109-
);
110-
variant_test_case!(
111-
29,
112-
"Unshredding not yet supported for type: Decimal128(38, 9)"
113-
);
89+
variant_test_case!(24);
90+
variant_test_case!(25);
91+
variant_test_case!(26);
92+
variant_test_case!(27);
93+
variant_test_case!(28);
94+
variant_test_case!(29);
11495
variant_test_case!(30);
11596
variant_test_case!(31);
11697
variant_test_case!(32);

0 commit comments

Comments
 (0)