@@ -25,15 +25,18 @@ use arrow::array::{
25
25
} ;
26
26
use arrow:: buffer:: NullBuffer ;
27
27
use arrow:: datatypes:: {
28
- ArrowPrimitiveType , DataType , Date32Type , Float32Type , Float64Type , Int8Type , Int16Type ,
29
- Int32Type , Int64Type , Time64MicrosecondType , TimeUnit , TimestampMicrosecondType ,
30
- TimestampNanosecondType ,
28
+ ArrowPrimitiveType , DataType , Date32Type , Decimal32Type , Decimal64Type , Decimal128Type ,
29
+ DecimalType , Float32Type , Float64Type , Int8Type , Int16Type , Int32Type , Int64Type ,
30
+ Time64MicrosecondType , TimeUnit , TimestampMicrosecondType , TimestampNanosecondType ,
31
31
} ;
32
32
use arrow:: error:: { ArrowError , Result } ;
33
33
use arrow:: temporal_conversions:: time64us_to_time;
34
34
use chrono:: { DateTime , Utc } ;
35
35
use indexmap:: IndexMap ;
36
- use parquet_variant:: { ObjectFieldBuilder , Variant , VariantBuilderExt , VariantMetadata } ;
36
+ use parquet_variant:: {
37
+ ObjectFieldBuilder , Variant , VariantBuilderExt , VariantDecimal4 , VariantDecimal8 ,
38
+ VariantDecimal16 , VariantMetadata ,
39
+ } ;
37
40
use uuid:: Uuid ;
38
41
39
42
/// Removes all (nested) typed_value columns from a VariantArray by converting them back to binary
@@ -92,6 +95,9 @@ enum UnshredVariantRowBuilder<'a> {
92
95
PrimitiveInt64 ( UnshredPrimitiveRowBuilder < ' a , PrimitiveArray < Int64Type > > ) ,
93
96
PrimitiveFloat32 ( UnshredPrimitiveRowBuilder < ' a , PrimitiveArray < Float32Type > > ) ,
94
97
PrimitiveFloat64 ( UnshredPrimitiveRowBuilder < ' a , PrimitiveArray < Float64Type > > ) ,
98
+ Decimal32 ( DecimalUnshredRowBuilder < ' a , Decimal32Spec > ) ,
99
+ Decimal64 ( DecimalUnshredRowBuilder < ' a , Decimal64Spec > ) ,
100
+ Decimal128 ( DecimalUnshredRowBuilder < ' a , Decimal128Spec > ) ,
95
101
PrimitiveDate32 ( UnshredPrimitiveRowBuilder < ' a , PrimitiveArray < Date32Type > > ) ,
96
102
PrimitiveTime64 ( UnshredPrimitiveRowBuilder < ' a , PrimitiveArray < Time64MicrosecondType > > ) ,
97
103
TimestampMicrosecond ( TimestampUnshredRowBuilder < ' a , TimestampMicrosecondType > ) ,
@@ -130,6 +136,9 @@ impl<'a> UnshredVariantRowBuilder<'a> {
130
136
Self :: PrimitiveInt64 ( b) => b. append_row ( builder, metadata, index) ,
131
137
Self :: PrimitiveFloat32 ( b) => b. append_row ( builder, metadata, index) ,
132
138
Self :: PrimitiveFloat64 ( b) => b. append_row ( builder, metadata, index) ,
139
+ Self :: Decimal32 ( b) => b. append_row ( builder, metadata, index) ,
140
+ Self :: Decimal64 ( b) => b. append_row ( builder, metadata, index) ,
141
+ Self :: Decimal128 ( b) => b. append_row ( builder, metadata, index) ,
133
142
Self :: PrimitiveDate32 ( b) => b. append_row ( builder, metadata, index) ,
134
143
Self :: PrimitiveTime64 ( b) => b. append_row ( builder, metadata, index) ,
135
144
Self :: TimestampMicrosecond ( b) => b. append_row ( builder, metadata, index) ,
@@ -176,6 +185,26 @@ impl<'a> UnshredVariantRowBuilder<'a> {
176
185
DataType :: Int64 => primitive_builder ! ( PrimitiveInt64 , as_primitive) ,
177
186
DataType :: Float32 => primitive_builder ! ( PrimitiveFloat32 , as_primitive) ,
178
187
DataType :: Float64 => primitive_builder ! ( PrimitiveFloat64 , as_primitive) ,
188
+ DataType :: Decimal32 ( _, scale) => Self :: Decimal32 ( DecimalUnshredRowBuilder :: new (
189
+ value,
190
+ typed_value. as_primitive ( ) ,
191
+ * scale,
192
+ ) ) ,
193
+ DataType :: Decimal64 ( _, scale) => Self :: Decimal64 ( DecimalUnshredRowBuilder :: new (
194
+ value,
195
+ typed_value. as_primitive ( ) ,
196
+ * scale,
197
+ ) ) ,
198
+ DataType :: Decimal128 ( _, scale) => Self :: Decimal128 ( DecimalUnshredRowBuilder :: new (
199
+ value,
200
+ typed_value. as_primitive ( ) ,
201
+ * scale,
202
+ ) ) ,
203
+ DataType :: Decimal256 ( _, _) => {
204
+ return Err ( ArrowError :: InvalidArgumentError (
205
+ "Decimal256 is not a valid variant shredding type" . to_string ( ) ,
206
+ ) ) ;
207
+ }
179
208
DataType :: Date32 => primitive_builder ! ( PrimitiveDate32 , as_primitive) ,
180
209
DataType :: Time64 ( TimeUnit :: Microsecond ) => {
181
210
primitive_builder ! ( PrimitiveTime64 , as_primitive)
@@ -475,6 +504,96 @@ impl<'a, T: TimestampType> TimestampUnshredRowBuilder<'a, T> {
475
504
}
476
505
}
477
506
507
+ /// Trait to unify decimal unshredding across Decimal32/64/128 types
508
+ trait DecimalSpec {
509
+ type Arrow : ArrowPrimitiveType + DecimalType ;
510
+
511
+ fn into_variant (
512
+ raw : <Self :: Arrow as ArrowPrimitiveType >:: Native ,
513
+ scale : i8 ,
514
+ ) -> Result < Variant < ' static , ' static > > ;
515
+ }
516
+
517
+ /// Spec for Decimal32 -> VariantDecimal4
518
+ struct Decimal32Spec ;
519
+
520
+ impl DecimalSpec for Decimal32Spec {
521
+ type Arrow = Decimal32Type ;
522
+
523
+ fn into_variant ( raw : i32 , scale : i8 ) -> Result < Variant < ' static , ' static > > {
524
+ let scale =
525
+ u8:: try_from ( scale) . map_err ( |e| ArrowError :: InvalidArgumentError ( e. to_string ( ) ) ) ?;
526
+ let value = VariantDecimal4 :: try_new ( raw, scale)
527
+ . map_err ( |e| ArrowError :: InvalidArgumentError ( e. to_string ( ) ) ) ?;
528
+ Ok ( value. into ( ) )
529
+ }
530
+ }
531
+
532
+ /// Spec for Decimal64 -> VariantDecimal8
533
+ struct Decimal64Spec ;
534
+
535
+ impl DecimalSpec for Decimal64Spec {
536
+ type Arrow = Decimal64Type ;
537
+
538
+ fn into_variant ( raw : i64 , scale : i8 ) -> Result < Variant < ' static , ' static > > {
539
+ let scale =
540
+ u8:: try_from ( scale) . map_err ( |e| ArrowError :: InvalidArgumentError ( e. to_string ( ) ) ) ?;
541
+ let value = VariantDecimal8 :: try_new ( raw, scale)
542
+ . map_err ( |e| ArrowError :: InvalidArgumentError ( e. to_string ( ) ) ) ?;
543
+ Ok ( value. into ( ) )
544
+ }
545
+ }
546
+
547
+ /// Spec for Decimal128 -> VariantDecimal16
548
+ struct Decimal128Spec ;
549
+
550
+ impl DecimalSpec for Decimal128Spec {
551
+ type Arrow = Decimal128Type ;
552
+
553
+ fn into_variant ( raw : i128 , scale : i8 ) -> Result < Variant < ' static , ' static > > {
554
+ let scale =
555
+ u8:: try_from ( scale) . map_err ( |e| ArrowError :: InvalidArgumentError ( e. to_string ( ) ) ) ?;
556
+ let value = VariantDecimal16 :: try_new ( raw, scale)
557
+ . map_err ( |e| ArrowError :: InvalidArgumentError ( e. to_string ( ) ) ) ?;
558
+ Ok ( value. into ( ) )
559
+ }
560
+ }
561
+
562
+ /// Generic builder for decimal unshredding that caches scale
563
+ struct DecimalUnshredRowBuilder < ' a , S : DecimalSpec > {
564
+ value : Option < & ' a BinaryViewArray > ,
565
+ typed_value : & ' a PrimitiveArray < S :: Arrow > ,
566
+ scale : i8 ,
567
+ }
568
+
569
+ impl < ' a , S : DecimalSpec > DecimalUnshredRowBuilder < ' a , S > {
570
+ fn new (
571
+ value : Option < & ' a BinaryViewArray > ,
572
+ typed_value : & ' a PrimitiveArray < S :: Arrow > ,
573
+ scale : i8 ,
574
+ ) -> Self {
575
+ Self {
576
+ value,
577
+ typed_value,
578
+ scale,
579
+ }
580
+ }
581
+
582
+ fn append_row (
583
+ & mut self ,
584
+ builder : & mut impl VariantBuilderExt ,
585
+ metadata : & VariantMetadata ,
586
+ index : usize ,
587
+ ) -> Result < ( ) > {
588
+ handle_unshredded_case ! ( self , builder, metadata, index, false ) ;
589
+
590
+ let raw = self . typed_value . value ( index) ;
591
+ let variant = S :: into_variant ( raw, self . scale ) ?;
592
+ builder. append_value ( variant) ;
593
+ Ok ( ( ) )
594
+ }
595
+ }
596
+
478
597
/// Builder for unshredding struct/object types with nested fields
479
598
struct StructUnshredVariantBuilder < ' a > {
480
599
value : Option < & ' a arrow:: array:: BinaryViewArray > ,
0 commit comments