diff --git a/datafusion/substrait/src/logical_plan/consumer/types.rs b/datafusion/substrait/src/logical_plan/consumer/types.rs index eb2cc967ca236..77c86bbbf7344 100644 --- a/datafusion/substrait/src/logical_plan/consumer/types.rs +++ b/datafusion/substrait/src/logical_plan/consumer/types.rs @@ -32,7 +32,9 @@ use crate::variation_const::{ TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF, }; -use crate::variation_const::{FLOAT_16_TYPE_NAME, NULL_TYPE_NAME}; +use crate::variation_const::{ + FIXED_SIZE_LIST_TYPE_NAME, FLOAT_16_TYPE_NAME, NULL_TYPE_NAME, +}; use datafusion::arrow::datatypes::{ DataType, Field, Fields, IntervalUnit, Schema, TimeUnit, }; @@ -256,6 +258,46 @@ pub fn from_substrait_type( } FLOAT_16_TYPE_NAME => Ok(DataType::Float16), NULL_TYPE_NAME => Ok(DataType::Null), + FIXED_SIZE_LIST_TYPE_NAME => { + if u.type_parameters.len() != 2 { + return substrait_err!( + "FixedSizeList requires 2 type parameters, got {}", + u.type_parameters.len() + ); + } + let inner_type = match &u.type_parameters[0].parameter { + Some(r#type::parameter::Parameter::DataType(t)) => { + from_substrait_type(consumer, t, dfs_names, name_idx)? 
+ } + _ => { + return substrait_err!( + "Invalid inner type for FixedSizeList" + ); + } + }; + let size = match &u.type_parameters[1].parameter { + Some(r#type::parameter::Parameter::Integer(i)) => { + // Arrow keeps the list length as an i32; a bare `as` cast + // would silently wrap an out-of-range Substrait integer. + match i32::try_from(*i) { + Ok(size) => size, + Err(_) => { + return substrait_err!( + "FixedSizeList size {} does not fit in i32", + i + ); + } + } + } + _ => { + return substrait_err!( + "Invalid size for FixedSizeList" + ); + } + }; + let field = Arc::new(Field::new_list_field(inner_type, true)); + Ok(DataType::FixedSizeList(field, size)) + } _ => not_impl_err!( "Unsupported Substrait user defined type with ref {} and variation {}", u.type_reference, diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs index 3727596119bc3..8a7fc2fcc2921 100644 --- a/datafusion/substrait/src/logical_plan/producer/types.rs +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -23,8 +23,8 @@ use crate::variation_const::{ DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF, DEFAULT_MAP_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, DICTIONARY_MAP_TYPE_VARIATION_REF, DURATION_INTERVAL_DAY_TYPE_VARIATION_REF, - FLOAT_16_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, NULL_TYPE_NAME, - TIME_32_TYPE_VARIATION_REF, TIME_64_TYPE_VARIATION_REF, + FIXED_SIZE_LIST_TYPE_NAME, FLOAT_16_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, + NULL_TYPE_NAME, TIME_32_TYPE_VARIATION_REF, TIME_64_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF, }; use datafusion::arrow::datatypes::{DataType, IntervalUnit}; @@ -286,6 +286,31 @@ pub(crate) fn to_substrait_type( }))), }) } + DataType::FixedSizeList(inner, size) => { + let inner_type = + to_substrait_type(producer, inner.data_type(), inner.is_nullable())?; + let type_anchor = + producer.register_type(FIXED_SIZE_LIST_TYPE_NAME.to_string()); + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::UserDefined(r#type::UserDefined { + type_reference: type_anchor, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, 
+ type_parameters: vec![ + r#type::Parameter { + parameter: Some(r#type::parameter::Parameter::DataType( + inner_type, + )), + }, + r#type::Parameter { + parameter: Some(r#type::parameter::Parameter::Integer( + *size as i64, + )), + }, + ], + })), + }) + } DataType::Map(inner, _) => match inner.data_type() { DataType::Struct(key_and_value) if key_and_value.len() == 2 => { let key_type = to_substrait_type( @@ -439,6 +464,10 @@ mod tests { round_trip_type(DataType::LargeList( Field::new_list_field(DataType::Int32, true).into(), ))?; + round_trip_type(DataType::FixedSizeList( + Field::new_list_field(DataType::Int64, true).into(), + 10, + ))?; round_trip_type(DataType::Map( Field::new_struct( diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs index b1a922899e976..b253ec7310859 100644 --- a/datafusion/substrait/src/variation_const.rs +++ b/datafusion/substrait/src/variation_const.rs @@ -130,3 +130,8 @@ pub const FLOAT_16_TYPE_NAME: &str = "fp16"; /// /// [`DataType::Null`]: datafusion::arrow::datatypes::DataType::Null pub const NULL_TYPE_NAME: &str = "null"; + +/// For [`DataType::FixedSizeList`] +/// +/// [`DataType::FixedSizeList`]: datafusion::arrow::datatypes::DataType::FixedSizeList +pub const FIXED_SIZE_LIST_TYPE_NAME: &str = "fixed_size_list";