Skip to content

Commit 9fd6490

Browse files
committed
fix tests
1 parent e419c35 commit 9fd6490

File tree

1 file changed

+26
-13
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+26
-13
lines changed

datafusion/physical-expr/src/expressions/cast.rs

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions {
4444
format_options: DEFAULT_FORMAT_OPTIONS,
4545
};
4646

47+
fn has_positional_fields(fields: &[FieldRef]) -> bool {
48+
fields.iter().enumerate().any(|(idx, f)| {
49+
f.name().is_empty()
50+
|| f.name()
51+
.as_str()
52+
.strip_prefix('c')
53+
.and_then(|suffix| suffix.parse::<usize>().ok())
54+
.map(|n| n == idx)
55+
.unwrap_or(false)
56+
})
57+
}
58+
4759
/// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast
4860
#[derive(Debug, Clone, Eq)]
4961
pub struct CastExpr {
@@ -150,34 +162,35 @@ impl PhysicalExpr for CastExpr {
150162

151163
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
152164
let value = self.expr.evaluate(batch)?;
153-
// delegate to cast_to, except for struct-to-struct casts
154165
let Struct(target_fields) = &self.cast_type else {
155166
return value.cast_to(&self.cast_type, Some(&self.cast_options));
156167
};
157168
let Struct(source_fields) = self.expr.data_type(batch.schema().as_ref())? else {
158169
return value.cast_to(&self.cast_type, Some(&self.cast_options));
159170
};
160-
if &source_fields == target_fields {
171+
172+
let use_struct_cast = target_fields.len() > source_fields.len()
173+
|| has_positional_fields(&source_fields)
174+
|| has_positional_fields(target_fields)
175+
|| target_fields
176+
.iter()
177+
.any(|t| source_fields.iter().all(|s| s.name() != t.name()));
178+
179+
if !use_struct_cast || source_fields == *target_fields {
161180
return value.cast_to(&self.cast_type, Some(&self.cast_options));
162-
};
181+
}
163182

164183
let target_field = self.return_field(batch.schema().as_ref())?;
165184
match value {
166185
ColumnarValue::Array(array) => {
167-
let casted = cast_column(
168-
&array,
169-
target_field.as_ref(),
170-
&self.cast_options,
171-
)?;
186+
let casted =
187+
cast_column(&array, target_field.as_ref(), &self.cast_options)?;
172188
Ok(ColumnarValue::Array(casted))
173189
}
174190
ColumnarValue::Scalar(scalar) => {
175191
let as_array = scalar.to_array_of_size(1)?;
176-
let casted = cast_column(
177-
&as_array,
178-
target_field.as_ref(),
179-
&self.cast_options,
180-
)?;
192+
let casted =
193+
cast_column(&as_array, target_field.as_ref(), &self.cast_options)?;
181194
let result = ScalarValue::try_from_array(casted.as_ref(), 0)?;
182195
Ok(ColumnarValue::Scalar(result))
183196
}

0 commit comments

Comments
 (0)