Skip to content

Commit 4fa8232

Browse files
authored
chore: refactor Substrait consumer's "rename_field" and implement the rest of types (#16345)
* refactor Substrait consumer's "rename_field" and implement the rest of types The rename_field is a bit confusing with it's "rename_self" parameter. Also there are times when one wants to just rename a data type. And I'd like to reuse the same code within our codebase, since with Substrait the case of renaming a field/type comes up a bunch, so I'd like to make these functions "pub". And lastly, this adds support for all list types, as well as dict/ree/union types. Dunno how necessary those are, but seems right to support them still. * datatype => data_type * fix clippy
1 parent 0e84041 commit 4fa8232

File tree

2 files changed

+181
-92
lines changed

2 files changed

+181
-92
lines changed

datafusion/substrait/src/logical_plan/consumer/expr/mod.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ pub use subquery::*;
3838
pub use window_function::*;
3939

4040
use crate::extensions::Extensions;
41-
use crate::logical_plan::consumer::utils::rename_field;
4241
use crate::logical_plan::consumer::{
43-
from_substrait_named_struct, DefaultSubstraitConsumer, SubstraitConsumer,
42+
from_substrait_named_struct, rename_field, DefaultSubstraitConsumer,
43+
SubstraitConsumer,
4444
};
4545
use datafusion::arrow::datatypes::Field;
4646
use datafusion::common::{not_impl_err, plan_err, substrait_err, DFSchema, DFSchemaRef};
@@ -152,7 +152,6 @@ pub async fn from_substrait_extended_expr(
152152
&substrait_expr.output_names,
153153
expr_idx,
154154
&mut names_idx,
155-
/*rename_self=*/ true,
156155
)?;
157156
exprs.push((expr, output_field));
158157
}

datafusion/substrait/src/logical_plan/consumer/utils.rs

Lines changed: 179 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
// under the License.
1717

1818
use crate::logical_plan::consumer::SubstraitConsumer;
19-
use datafusion::arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema};
19+
use datafusion::arrow::datatypes::{DataType, Field, Schema, UnionFields};
2020
use datafusion::common::{
21-
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
22-
TableReference,
21+
exec_err, not_impl_err, substrait_datafusion_err, substrait_err, DFSchema,
22+
DFSchemaRef, TableReference,
2323
};
2424
use datafusion::logical_expr::expr::Sort;
2525
use datafusion::logical_expr::{Cast, Expr, ExprSchemable, LogicalPlanBuilder};
@@ -81,98 +81,167 @@ pub(super) fn next_struct_field_name(
8181
}
8282
}
8383

84-
pub(super) fn rename_field(
84+
/// Traverse through the field, renaming the provided field itself and all its inner struct fields.
85+
pub fn rename_field(
8586
field: &Field,
8687
dfs_names: &Vec<String>,
8788
unnamed_field_suffix: usize, // If Substrait doesn't provide a name, we'll use this "c{unnamed_field_suffix}"
8889
name_idx: &mut usize, // Index into dfs_names
89-
rename_self: bool, // Some fields (e.g. list items) don't have names in Substrait and this will be false to keep old name
9090
) -> datafusion::common::Result<Field> {
91-
let name = if rename_self {
92-
next_struct_field_name(unnamed_field_suffix, dfs_names, name_idx)?
93-
} else {
94-
field.name().to_string()
95-
};
96-
match field.data_type() {
91+
let name = next_struct_field_name(unnamed_field_suffix, dfs_names, name_idx)?;
92+
rename_fields_data_type(field.clone().with_name(name), dfs_names, name_idx)
93+
}
94+
95+
/// Rename the field's data type but not the field itself.
96+
pub fn rename_fields_data_type(
97+
field: Field,
98+
dfs_names: &Vec<String>,
99+
name_idx: &mut usize, // Index into dfs_names
100+
) -> datafusion::common::Result<Field> {
101+
let dt = rename_data_type(field.data_type(), dfs_names, name_idx)?;
102+
Ok(field.with_data_type(dt))
103+
}
104+
105+
/// Traverse through the data type (incl. lists/maps/etc), renaming all inner struct fields.
106+
pub fn rename_data_type(
107+
data_type: &DataType,
108+
dfs_names: &Vec<String>,
109+
name_idx: &mut usize, // Index into dfs_names
110+
) -> datafusion::common::Result<DataType> {
111+
match data_type {
97112
DataType::Struct(children) => {
98113
let children = children
99114
.iter()
100115
.enumerate()
101-
.map(|(child_idx, f)| {
102-
rename_field(
103-
f.as_ref(),
104-
dfs_names,
105-
child_idx,
106-
name_idx,
107-
/*rename_self=*/ true,
108-
)
116+
.map(|(field_idx, f)| {
117+
rename_field(f.as_ref(), dfs_names, field_idx, name_idx)
109118
})
110119
.collect::<datafusion::common::Result<_>>()?;
111-
Ok(field
112-
.to_owned()
113-
.with_name(name)
114-
.with_data_type(DataType::Struct(children)))
120+
Ok(DataType::Struct(children))
115121
}
116-
DataType::List(inner) => {
117-
let renamed_inner = rename_field(
118-
inner.as_ref(),
122+
DataType::List(inner) => Ok(DataType::List(Arc::new(rename_fields_data_type(
123+
inner.as_ref().to_owned(),
124+
dfs_names,
125+
name_idx,
126+
)?))),
127+
DataType::LargeList(inner) => Ok(DataType::LargeList(Arc::new(
128+
rename_fields_data_type(inner.as_ref().to_owned(), dfs_names, name_idx)?,
129+
))),
130+
DataType::ListView(inner) => Ok(DataType::ListView(Arc::new(
131+
rename_fields_data_type(inner.as_ref().to_owned(), dfs_names, name_idx)?,
132+
))),
133+
DataType::LargeListView(inner) => Ok(DataType::LargeListView(Arc::new(
134+
rename_fields_data_type(inner.as_ref().to_owned(), dfs_names, name_idx)?,
135+
))),
136+
DataType::FixedSizeList(inner, len) => Ok(DataType::FixedSizeList(
137+
Arc::new(rename_fields_data_type(
138+
inner.as_ref().to_owned(),
119139
dfs_names,
120-
0,
121140
name_idx,
122-
/*rename_self=*/ false,
123-
)?;
124-
Ok(field
125-
.to_owned()
126-
.with_data_type(DataType::List(FieldRef::new(renamed_inner)))
127-
.with_name(name))
141+
)?),
142+
*len,
143+
)),
144+
DataType::Map(entries, sorted) => {
145+
let entries_data_type = match entries.data_type() {
146+
DataType::Struct(fields) => {
147+
// This should be two fields, normally "key" and "value", but not guaranteed
148+
let fields = fields
149+
.iter()
150+
.map(|f| {
151+
rename_fields_data_type(
152+
f.as_ref().to_owned(),
153+
dfs_names,
154+
name_idx,
155+
)
156+
})
157+
.collect::<datafusion::common::Result<_>>()?;
158+
Ok(DataType::Struct(fields))
159+
}
160+
_ => exec_err!("Expected map type to contain an inner struct type"),
161+
}?;
162+
Ok(DataType::Map(
163+
Arc::new(
164+
entries
165+
.as_ref()
166+
.to_owned()
167+
.with_data_type(entries_data_type),
168+
),
169+
*sorted,
170+
))
128171
}
129-
DataType::LargeList(inner) => {
130-
let renamed_inner = rename_field(
131-
inner.as_ref(),
172+
DataType::Dictionary(key_type, value_type) => {
173+
// Dicts probably shouldn't contain structs, but support them just in case one does
174+
Ok(DataType::Dictionary(
175+
Box::new(rename_data_type(key_type, dfs_names, name_idx)?),
176+
Box::new(rename_data_type(value_type, dfs_names, name_idx)?),
177+
))
178+
}
179+
DataType::RunEndEncoded(run_ends_field, values_field) => {
180+
// At least the run_ends_field shouldn't contain names (since it should be i16/i32/i64),
181+
// but we'll try renaming its datatype just in case.
182+
let run_ends_field = rename_fields_data_type(
183+
run_ends_field.as_ref().clone(),
184+
dfs_names,
185+
name_idx,
186+
)?;
187+
let values_field = rename_fields_data_type(
188+
values_field.as_ref().clone(),
132189
dfs_names,
133-
0,
134190
name_idx,
135-
/*rename_self= */ false,
136191
)?;
137-
Ok(field
138-
.to_owned()
139-
.with_data_type(DataType::LargeList(FieldRef::new(renamed_inner)))
140-
.with_name(name))
192+
193+
Ok(DataType::RunEndEncoded(
194+
Arc::new(run_ends_field),
195+
Arc::new(values_field),
196+
))
141197
}
142-
DataType::Map(inner, sorted) => match inner.data_type() {
143-
DataType::Struct(key_and_value) if key_and_value.len() == 2 => {
144-
let renamed_keys = rename_field(
145-
key_and_value[0].as_ref(),
146-
dfs_names,
147-
0,
148-
name_idx,
149-
/*rename_self=*/ false,
150-
)?;
151-
let renamed_values = rename_field(
152-
key_and_value[1].as_ref(),
153-
dfs_names,
154-
0,
155-
name_idx,
156-
/*rename_self=*/ false,
157-
)?;
158-
Ok(field
159-
.to_owned()
160-
.with_data_type(DataType::Map(
161-
Arc::new(Field::new(
162-
inner.name(),
163-
DataType::Struct(Fields::from(vec![
164-
renamed_keys,
165-
renamed_values,
166-
])),
167-
inner.is_nullable(),
168-
)),
169-
*sorted,
198+
DataType::Union(fields, mode) => {
199+
let fields = fields
200+
.iter()
201+
.map(|(i, f)| {
202+
Ok((
203+
i,
204+
Arc::new(rename_fields_data_type(
205+
f.as_ref().clone(),
206+
dfs_names,
207+
name_idx,
208+
)?),
170209
))
171-
.with_name(name))
172-
}
173-
_ => substrait_err!("Map fields must contain a Struct with exactly 2 fields"),
174-
},
175-
_ => Ok(field.to_owned().with_name(name)),
210+
})
211+
.collect::<datafusion::common::Result<UnionFields>>()?;
212+
Ok(DataType::Union(fields, *mode))
213+
}
214+
// Explicitly listing the rest (which can not contain inner fields needing renaming)
215+
// to ensure we're exhaustive
216+
DataType::Null
217+
| DataType::Boolean
218+
| DataType::Int8
219+
| DataType::Int16
220+
| DataType::Int32
221+
| DataType::Int64
222+
| DataType::UInt8
223+
| DataType::UInt16
224+
| DataType::UInt32
225+
| DataType::UInt64
226+
| DataType::Float16
227+
| DataType::Float32
228+
| DataType::Float64
229+
| DataType::Timestamp(_, _)
230+
| DataType::Date32
231+
| DataType::Date64
232+
| DataType::Time32(_)
233+
| DataType::Time64(_)
234+
| DataType::Duration(_)
235+
| DataType::Interval(_)
236+
| DataType::Binary
237+
| DataType::FixedSizeBinary(_)
238+
| DataType::LargeBinary
239+
| DataType::BinaryView
240+
| DataType::Utf8
241+
| DataType::LargeUtf8
242+
| DataType::Utf8View
243+
| DataType::Decimal128(_, _)
244+
| DataType::Decimal256(_, _) => Ok(data_type.clone()),
176245
}
177246
}
178247

@@ -190,13 +259,8 @@ pub(super) fn make_renamed_schema(
190259
.iter()
191260
.enumerate()
192261
.map(|(field_idx, (q, f))| {
193-
let renamed_f = rename_field(
194-
f.as_ref(),
195-
dfs_names,
196-
field_idx,
197-
&mut name_idx,
198-
/*rename_self=*/ true,
199-
)?;
262+
let renamed_f =
263+
rename_field(f.as_ref(), dfs_names, field_idx, &mut name_idx)?;
200264
Ok((q.cloned(), renamed_f))
201265
})
202266
.collect::<datafusion::common::Result<Vec<_>>>()?
@@ -473,17 +537,29 @@ pub(crate) mod tests {
473537
),
474538
(
475539
Some(table_ref.clone()),
476-
Arc::new(Field::new_map(
540+
Arc::new(Field::new_large_list(
477541
"7",
542+
Arc::new(Field::new_struct(
543+
"item",
544+
vec![Field::new("8", DataType::Int32, false)],
545+
false,
546+
)),
547+
false,
548+
)),
549+
),
550+
(
551+
Some(table_ref.clone()),
552+
Arc::new(Field::new_map(
553+
"9",
478554
"entries",
479555
Arc::new(Field::new_struct(
480556
"keys",
481-
vec![Field::new("8", DataType::Int32, false)],
557+
vec![Field::new("10", DataType::Int32, false)],
482558
false,
483559
)),
484560
Arc::new(Field::new_struct(
485561
"values",
486-
vec![Field::new("9", DataType::Int32, false)],
562+
vec![Field::new("11", DataType::Int32, false)],
487563
false,
488564
)),
489565
false,
@@ -504,10 +580,12 @@ pub(crate) mod tests {
504580
"h".to_string(),
505581
"i".to_string(),
506582
"j".to_string(),
583+
"k".to_string(),
584+
"l".to_string(),
507585
];
508586
let renamed_schema = make_renamed_schema(&schema, &dfs_names)?;
509587

510-
assert_eq!(renamed_schema.fields().len(), 4);
588+
assert_eq!(renamed_schema.fields().len(), 5);
511589
assert_eq!(
512590
*renamed_schema.field(0),
513591
Field::new("a", DataType::Int32, false)
@@ -541,17 +619,29 @@ pub(crate) mod tests {
541619
);
542620
assert_eq!(
543621
*renamed_schema.field(3),
544-
Field::new_map(
622+
Field::new_large_list(
545623
"h",
624+
Arc::new(Field::new_struct(
625+
"item",
626+
vec![Field::new("i", DataType::Int32, false)],
627+
false,
628+
)),
629+
false,
630+
)
631+
);
632+
assert_eq!(
633+
*renamed_schema.field(4),
634+
Field::new_map(
635+
"j",
546636
"entries",
547637
Arc::new(Field::new_struct(
548638
"keys",
549-
vec![Field::new("i", DataType::Int32, false)],
639+
vec![Field::new("k", DataType::Int32, false)],
550640
false,
551641
)),
552642
Arc::new(Field::new_struct(
553643
"values",
554-
vec![Field::new("j", DataType::Int32, false)],
644+
vec![Field::new("l", DataType::Int32, false)],
555645
false,
556646
)),
557647
false,

0 commit comments

Comments
 (0)