Skip to content

Commit de32fc8

Browse files
authored
feat: Override MapBuilder values field with expected schema (#1643)
* feat: Override MapBuilder values field with expected schema * fmt
1 parent 71251ea commit de32fc8

File tree

3 files changed

+601
-752
lines changed

3 files changed

+601
-752
lines changed

native/core/src/execution/shuffle/map.rs

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1925,10 +1925,10 @@ pub fn append_map_elements<K: ArrayBuilder, V: ArrayBuilder>(
19251925
map_builder: &mut MapBuilder<K, V>,
19261926
map: &SparkUnsafeMap,
19271927
) -> Result<(), CometError> {
1928-
let (key_dt, value_dt, _) = get_map_key_value_dt(field)?;
1928+
let (key_field, value_field, _) = get_map_key_value_fields(field)?;
19291929

19301930
// macro cannot expand to match arm
1931-
match (key_dt, value_dt) {
1931+
match (key_field.data_type(), value_field.data_type()) {
19321932
(DataType::Boolean, DataType::Boolean) => {
19331933
let map_builder =
19341934
downcast_builder_ref!(MapBuilder<BooleanBuilder, BooleanBuilder>, map_builder);
@@ -2823,7 +2823,8 @@ pub fn append_map_elements<K: ArrayBuilder, V: ArrayBuilder>(
28232823
_ => {
28242824
return Err(CometError::Internal(format!(
28252825
"Unsupported map key/value data type: {:?}/{:?}",
2826-
key_dt, value_dt
2826+
key_field.data_type(),
2827+
value_field.data_type()
28272828
)))
28282829
}
28292830
}
@@ -2832,13 +2833,13 @@ pub fn append_map_elements<K: ArrayBuilder, V: ArrayBuilder>(
28322833
}
28332834

28342835
#[allow(clippy::field_reassign_with_default)]
2835-
pub fn get_map_key_value_dt(
2836+
pub fn get_map_key_value_fields(
28362837
field: &FieldRef,
2837-
) -> Result<(&DataType, &DataType, MapFieldNames), CometError> {
2838+
) -> Result<(&FieldRef, &FieldRef, MapFieldNames), CometError> {
28382839
let mut map_fieldnames = MapFieldNames::default();
28392840
map_fieldnames.entry = field.name().to_string();
28402841

2841-
let (key_dt, value_dt) = match field.data_type() {
2842+
let (key_field, value_field) = match field.data_type() {
28422843
DataType::Struct(fields) => {
28432844
if fields.len() != 2 {
28442845
return Err(CometError::Internal(format!(
@@ -2847,12 +2848,13 @@ pub fn get_map_key_value_dt(
28472848
)));
28482849
}
28492850

2850-
map_fieldnames.key = fields[0].name().to_string();
2851-
map_fieldnames.value = fields[1].name().to_string();
2851+
let key = &fields[0];
2852+
let value = &fields[1];
28522853

2853-
let key_dt = fields[0].data_type();
2854-
let value_dt = fields[1].data_type();
2855-
(key_dt, value_dt)
2854+
map_fieldnames.key = key.name().to_string();
2855+
map_fieldnames.value = value.name().to_string();
2856+
2857+
(key, value)
28562858
}
28572859
_ => {
28582860
return Err(CometError::Internal(format!(
@@ -2862,5 +2864,5 @@ pub fn get_map_key_value_dt(
28622864
}
28632865
};
28642866

2865-
Ok((key_dt, value_dt, map_fieldnames))
2867+
Ok((key_field, value_field, map_fieldnames))
28662868
}

0 commit comments

Comments
 (0)