Skip to content

Commit 55a5863

Browse files
authored
generate parquet schema from rust struct (apache#539)
* generate parquet schema from rust struct * support all primitive types through logical types
1 parent b05edf4 commit 55a5863

File tree

4 files changed

+236
-41
lines changed

4 files changed

+236
-41
lines changed

parquet/src/record/record_writer.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::schema::types::TypePtr;
19+
1820
use super::super::errors::ParquetError;
1921
use super::super::file::writer::RowGroupWriter;
2022

@@ -23,4 +25,7 @@ pub trait RecordWriter<T> {
2325
&self,
2426
row_group_writer: &mut Box<dyn RowGroupWriter>,
2527
) -> Result<(), ParquetError>;
28+
29+
/// Generated schema
30+
fn schema(&self) -> Result<TypePtr, ParquetError>;
2631
}

parquet_derive/src/lib.rs

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,6 @@ mod parquet_field;
5252
/// pub a_str: &'a str,
5353
/// }
5454
///
55-
/// let schema_str = "message schema {
56-
/// REQUIRED boolean a_bool;
57-
/// REQUIRED BINARY a_str (UTF8);
58-
/// }";
59-
///
6055
/// pub fn write_some_records() {
6156
/// let samples = vec![
6257
/// ACompleteRecord {
@@ -69,7 +64,7 @@ mod parquet_field;
6964
/// }
7065
/// ];
7166
///
72-
/// let schema = Arc::new(parse_message_type(schema_str).unwrap());
67+
/// let schema = samples.as_slice().schema();
7368
///
7469
/// let props = Arc::new(WriterProperties::builder().build());
7570
/// let mut writer = SerializedFileWriter::new(file, schema, props).unwrap();
@@ -101,9 +96,15 @@ pub fn parquet_record_writer(input: proc_macro::TokenStream) -> proc_macro::Toke
10196
let derived_for = input.ident;
10297
let generics = input.generics;
10398

99+
let field_types: Vec<proc_macro2::TokenStream> =
100+
field_infos.iter().map(|x| x.parquet_type()).collect();
101+
104102
(quote! {
105103
impl#generics RecordWriter<#derived_for#generics> for &[#derived_for#generics] {
106-
fn write_to_row_group(&self, row_group_writer: &mut Box<parquet::file::writer::RowGroupWriter>) -> Result<(), parquet::errors::ParquetError> {
104+
fn write_to_row_group(
105+
&self,
106+
row_group_writer: &mut Box<parquet::file::writer::RowGroupWriter>
107+
) -> Result<(), parquet::errors::ParquetError> {
107108
let mut row_group_writer = row_group_writer;
108109
let records = &self; // Used by all the writer snippets to be more clear
109110

@@ -121,6 +122,22 @@ pub fn parquet_record_writer(input: proc_macro::TokenStream) -> proc_macro::Toke
121122

122123
Ok(())
123124
}
125+
126+
fn schema(&self) -> Result<parquet::schema::types::TypePtr, parquet::errors::ParquetError> {
127+
use parquet::schema::types::Type as ParquetType;
128+
use parquet::schema::types::TypePtr;
129+
use parquet::basic::LogicalType;
130+
use parquet::basic::*;
131+
132+
let mut fields: Vec<TypePtr> = Vec::new();
133+
#(
134+
#field_types
135+
);*;
136+
let group = parquet::schema::types::Type::group_type_builder("rust_schema")
137+
.with_fields(&mut fields)
138+
.build()?;
139+
Ok(group.into())
140+
}
124141
}
125142
}).into()
126143
}

parquet_derive/src/parquet_field.rs

Lines changed: 151 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,50 @@ impl Field {
174174
}
175175
}
176176

177+
pub fn parquet_type(&self) -> proc_macro2::TokenStream {
178+
// TODO: Support group types
179+
// TODO: Add length if dealing with fixedlenbinary
180+
181+
let field_name = &self.ident.to_string();
182+
let physical_type = match self.ty.physical_type() {
183+
parquet::basic::Type::BOOLEAN => quote! {
184+
parquet::basic::Type::BOOLEAN
185+
},
186+
parquet::basic::Type::INT32 => quote! {
187+
parquet::basic::Type::INT32
188+
},
189+
parquet::basic::Type::INT64 => quote! {
190+
parquet::basic::Type::INT64
191+
},
192+
parquet::basic::Type::INT96 => quote! {
193+
parquet::basic::Type::INT96
194+
},
195+
parquet::basic::Type::FLOAT => quote! {
196+
parquet::basic::Type::FLOAT
197+
},
198+
parquet::basic::Type::DOUBLE => quote! {
199+
parquet::basic::Type::DOUBLE
200+
},
201+
parquet::basic::Type::BYTE_ARRAY => quote! {
202+
parquet::basic::Type::BYTE_ARRAY
203+
},
204+
parquet::basic::Type::FIXED_LEN_BYTE_ARRAY => quote! {
205+
parquet::basic::Type::FIXED_LEN_BYTE_ARRAY
206+
},
207+
};
208+
let logical_type = self.ty.logical_type();
209+
let repetition = self.ty.repetition();
210+
quote! {
211+
fields.push(ParquetType::primitive_type_builder(#field_name, #physical_type)
212+
.with_logical_type(#logical_type)
213+
.with_repetition(#repetition)
214+
.build()
215+
.unwrap()
216+
.into()
217+
);
218+
}
219+
}
220+
177221
fn option_into_vals(&self) -> proc_macro2::TokenStream {
178222
let field_name = &self.ident;
179223
let is_a_byte_buf = self.is_a_byte_buf;
@@ -201,7 +245,12 @@ impl Field {
201245
} else if is_a_byte_buf {
202246
quote! { Some((&inner[..]).into())}
203247
} else {
204-
quote! { Some(inner) }
248+
// Type might need converting to a physical type
249+
match self.ty.physical_type() {
250+
parquet::basic::Type::INT32 => quote! { Some(inner as i32) },
251+
parquet::basic::Type::INT64 => quote! { Some(inner as i64) },
252+
_ => quote! { Some(inner) },
253+
}
205254
};
206255

207256
quote! {
@@ -232,7 +281,12 @@ impl Field {
232281
} else if is_a_byte_buf {
233282
quote! { (&rec.#field_name[..]).into() }
234283
} else {
235-
quote! { rec.#field_name }
284+
// Type might need converting to a physical type
285+
match self.ty.physical_type() {
286+
parquet::basic::Type::INT32 => quote! { rec.#field_name as i32 },
287+
parquet::basic::Type::INT64 => quote! { rec.#field_name as i64 },
288+
_ => quote! { rec.#field_name },
289+
}
236290
};
237291

238292
quote! {
@@ -403,14 +457,98 @@ impl Type {
403457
"bool" => BasicType::BOOLEAN,
404458
"u8" | "u16" | "u32" => BasicType::INT32,
405459
"i8" | "i16" | "i32" | "NaiveDate" => BasicType::INT32,
406-
"u64" | "i64" | "usize" | "NaiveDateTime" => BasicType::INT64,
460+
"u64" | "i64" | "NaiveDateTime" => BasicType::INT64,
461+
"usize" | "isize" => {
462+
if usize::BITS == 64 {
463+
BasicType::INT64
464+
} else {
465+
BasicType::INT32
466+
}
467+
}
407468
"f32" => BasicType::FLOAT,
408469
"f64" => BasicType::DOUBLE,
409470
"String" | "str" | "Uuid" => BasicType::BYTE_ARRAY,
410471
f => unimplemented!("{} currently is not supported", f),
411472
}
412473
}
413474

475+
fn logical_type(&self) -> proc_macro2::TokenStream {
476+
let last_part = self.last_part();
477+
let leaf_type = self.leaf_type_recursive();
478+
479+
match leaf_type {
480+
Type::Array(ref first_type) => {
481+
if let Type::TypePath(_) = **first_type {
482+
if last_part == "u8" {
483+
return quote! { None };
484+
}
485+
}
486+
}
487+
Type::Vec(ref first_type) => {
488+
if let Type::TypePath(_) = **first_type {
489+
if last_part == "u8" {
490+
return quote! { None };
491+
}
492+
}
493+
}
494+
_ => (),
495+
}
496+
497+
match last_part.trim() {
498+
"bool" => quote! { None },
499+
"u8" => quote! { Some(LogicalType::INTEGER(IntType {
500+
bit_width: 8,
501+
is_signed: false,
502+
})) },
503+
"u16" => quote! { Some(LogicalType::INTEGER(IntType {
504+
bit_width: 16,
505+
is_signed: false,
506+
})) },
507+
"u32" => quote! { Some(LogicalType::INTEGER(IntType {
508+
bit_width: 32,
509+
is_signed: false,
510+
})) },
511+
"u64" => quote! { Some(LogicalType::INTEGER(IntType {
512+
bit_width: 64,
513+
is_signed: false,
514+
})) },
515+
"i8" => quote! { Some(LogicalType::INTEGER(IntType {
516+
bit_width: 8,
517+
is_signed: true,
518+
})) },
519+
"i16" => quote! { Some(LogicalType::INTEGER(IntType {
520+
bit_width: 16,
521+
is_signed: true,
522+
})) },
523+
"i32" | "i64" => quote! { None },
524+
"usize" => {
525+
quote! { Some(LogicalType::INTEGER(IntType {
526+
bit_width: usize::BITS as i8,
527+
is_signed: false
528+
})) }
529+
}
530+
"isize" => {
531+
quote! { Some(LogicalType::INTEGER(IntType {
532+
bit_width: usize::BITS as i8,
533+
is_signed: true
534+
})) }
535+
}
536+
"NaiveDate" => quote! { Some(LogicalType::DATE(Default::default())) },
537+
"f32" | "f64" => quote! { None },
538+
"String" | "str" => quote! { Some(LogicalType::STRING(Default::default())) },
539+
"Uuid" => quote! { Some(LogicalType::UUID(Default::default())) },
540+
f => unimplemented!("{} currently is not supported", f),
541+
}
542+
}
543+
544+
fn repetition(&self) -> proc_macro2::TokenStream {
545+
match &self {
546+
Type::Option(_) => quote! { Repetition::OPTIONAL },
547+
Type::Reference(_, ty) => ty.repetition(),
548+
_ => quote! { Repetition::REQUIRED },
549+
}
550+
}
551+
414552
/// Convert a parsed rust field AST in to a more easy to manipulate
415553
/// parquet_derive::Field
416554
fn from(f: &syn::Field) -> Self {
@@ -505,7 +643,7 @@ mod test {
505643
assert_eq!(snippet,
506644
(quote!{
507645
{
508-
let vals : Vec < _ > = records . iter ( ) . map ( | rec | rec . counter ) . collect ( );
646+
let vals : Vec < _ > = records . iter ( ) . map ( | rec | rec . counter as i64 ) . collect ( );
509647

510648
if let parquet::column::writer::ColumnWriter::Int64ColumnWriter ( ref mut typed ) = column_writer {
511649
typed . write_batch ( & vals [ .. ] , None , None ) ?;
@@ -585,7 +723,7 @@ mod test {
585723

586724
let vals: Vec <_> = records.iter().filter_map( |rec| {
587725
if let Some ( inner ) = rec . optional_dumb_int {
588-
Some ( inner )
726+
Some ( inner as i32 )
589727
} else {
590728
None
591729
}
@@ -636,12 +774,13 @@ mod test {
636774
struct ABasicStruct {
637775
yes_no: bool,
638776
name: String,
777+
length: usize
639778
}
640779
};
641780

642781
let fields = extract_fields(snippet);
643782
let processed: Vec<_> = fields.iter().map(|field| Field::from(field)).collect();
644-
assert_eq!(processed.len(), 2);
783+
assert_eq!(processed.len(), 3);
645784

646785
assert_eq!(
647786
processed,
@@ -657,6 +796,12 @@ mod test {
657796
ty: Type::TypePath(syn::parse_quote!(String)),
658797
is_a_byte_buf: true,
659798
third_party_type: None,
799+
},
800+
Field {
801+
ident: syn::Ident::new("length", proc_macro2::Span::call_site()),
802+
ty: Type::TypePath(syn::parse_quote!(usize)),
803+
is_a_byte_buf: false,
804+
third_party_type: None,
660805
}
661806
]
662807
)

0 commit comments

Comments
 (0)