@@ -26,7 +26,10 @@ use arrow::compute::{can_cast_types, CastOptions};
2626use arrow:: datatypes:: { DataType , DataType :: * , FieldRef , Schema } ;
2727use arrow:: record_batch:: RecordBatch ;
2828use datafusion_common:: format:: DEFAULT_FORMAT_OPTIONS ;
29- use datafusion_common:: { not_impl_err, Result } ;
29+ use datafusion_common:: {
30+ nested_struct:: { cast_column, validate_struct_compatibility} ,
31+ not_impl_err, Result , ScalarValue ,
32+ } ;
3033use datafusion_expr_common:: columnar_value:: ColumnarValue ;
3134use datafusion_expr_common:: interval_arithmetic:: Interval ;
3235use datafusion_expr_common:: sort_properties:: ExprProperties ;
@@ -41,6 +44,18 @@ const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions {
4144 format_options : DEFAULT_FORMAT_OPTIONS ,
4245} ;
4346
47+ fn has_positional_fields ( fields : & [ FieldRef ] ) -> bool {
48+ fields. iter ( ) . enumerate ( ) . any ( |( idx, f) | {
49+ f. name ( ) . is_empty ( )
50+ || f. name ( )
51+ . as_str ( )
52+ . strip_prefix ( 'c' )
53+ . and_then ( |suffix| suffix. parse :: < usize > ( ) . ok ( ) )
54+ . map ( |n| n == idx)
55+ . unwrap_or ( false )
56+ } )
57+ }
58+
4459/// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast
4560#[ derive( Debug , Clone , Eq ) ]
4661pub struct CastExpr {
@@ -138,12 +153,48 @@ impl PhysicalExpr for CastExpr {
138153 }
139154
140155 fn nullable ( & self , input_schema : & Schema ) -> Result < bool > {
141- self . expr . nullable ( input_schema)
156+ if matches ! ( self . cast_type, Struct ( _) ) {
157+ Ok ( self . return_field ( input_schema) ?. is_nullable ( ) )
158+ } else {
159+ self . expr . nullable ( input_schema)
160+ }
142161 }
143162
144163 fn evaluate ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
145164 let value = self . expr . evaluate ( batch) ?;
146- value. cast_to ( & self . cast_type , Some ( & self . cast_options ) )
165+ let Struct ( target_fields) = & self . cast_type else {
166+ return value. cast_to ( & self . cast_type , Some ( & self . cast_options ) ) ;
167+ } ;
168+ let Struct ( source_fields) = self . expr . data_type ( batch. schema ( ) . as_ref ( ) ) ? else {
169+ return value. cast_to ( & self . cast_type , Some ( & self . cast_options ) ) ;
170+ } ;
171+
172+ let use_struct_cast = target_fields. len ( ) > source_fields. len ( )
173+ || has_positional_fields ( & source_fields)
174+ || has_positional_fields ( target_fields)
175+ || target_fields
176+ . iter ( )
177+ . any ( |t| source_fields. iter ( ) . all ( |s| s. name ( ) != t. name ( ) ) ) ;
178+
179+ if !use_struct_cast || source_fields == * target_fields {
180+ return value. cast_to ( & self . cast_type , Some ( & self . cast_options ) ) ;
181+ }
182+
183+ let target_field = self . return_field ( batch. schema ( ) . as_ref ( ) ) ?;
184+ match value {
185+ ColumnarValue :: Array ( array) => {
186+ let casted =
187+ cast_column ( & array, target_field. as_ref ( ) , & self . cast_options ) ?;
188+ Ok ( ColumnarValue :: Array ( casted) )
189+ }
190+ ColumnarValue :: Scalar ( scalar) => {
191+ let as_array = scalar. to_array_of_size ( 1 ) ?;
192+ let casted =
193+ cast_column ( & as_array, target_field. as_ref ( ) , & self . cast_options ) ?;
194+ let result = ScalarValue :: try_from_array ( casted. as_ref ( ) , 0 ) ?;
195+ Ok ( ColumnarValue :: Scalar ( result) )
196+ }
197+ }
147198 }
148199
149200 fn return_field ( & self , input_schema : & Schema ) -> Result < FieldRef > {
@@ -229,6 +280,13 @@ pub fn cast_with_options(
229280 let expr_type = expr. data_type ( input_schema) ?;
230281 if expr_type == cast_type {
231282 Ok ( Arc :: clone ( & expr) )
283+ } else if let Struct ( target_fields) = & cast_type {
284+ if let Struct ( source_fields) = expr_type {
285+ validate_struct_compatibility ( & source_fields, target_fields) ?;
286+ } else if expr_type != Null {
287+ return not_impl_err ! ( "Unsupported CAST from {expr_type} to {cast_type}" ) ;
288+ }
289+ Ok ( Arc :: new ( CastExpr :: new ( expr, cast_type, cast_options) ) )
232290 } else if can_cast_types ( & expr_type, & cast_type) {
233291 Ok ( Arc :: new ( CastExpr :: new ( expr, cast_type, cast_options) ) )
234292 } else {
@@ -252,16 +310,17 @@ pub fn cast(
252310mod tests {
253311 use super :: * ;
254312
255- use crate :: expressions:: column:: col;
313+ use crate :: expressions:: { column:: col, Column , Literal } ;
256314
257315 use arrow:: {
258316 array:: {
259- Array , Decimal128Array , Float32Array , Float64Array , Int16Array , Int32Array ,
260- Int64Array , Int8Array , StringArray , Time64NanosecondArray ,
261- TimestampNanosecondArray , UInt32Array ,
317+ Array , ArrayRef , BooleanArray , Decimal128Array , Float32Array , Float64Array ,
318+ Int16Array , Int32Array , Int64Array , Int8Array , StringArray , StructArray ,
319+ Time64NanosecondArray , TimestampNanosecondArray , UInt32Array ,
262320 } ,
263321 datatypes:: * ,
264322 } ;
323+ use datafusion_common:: cast:: { as_int64_array, as_string_array, as_uint8_array} ;
265324 use datafusion_physical_expr_common:: physical_expr:: fmt_sql;
266325 use insta:: assert_snapshot;
267326
@@ -809,4 +868,153 @@ mod tests {
809868
810869 Ok ( ( ) )
811870 }
871+
872+ fn make_schema ( field : & Field ) -> SchemaRef {
873+ Arc :: new ( Schema :: new ( vec ! [ field. clone( ) ] ) )
874+ }
875+
876+ fn make_struct_array ( fields : Fields , arrays : Vec < ArrayRef > ) -> StructArray {
877+ StructArray :: new ( fields, arrays, None )
878+ }
879+
880+ /// Casts one struct array to another with different fields.
881+ fn cast_struct_array (
882+ input : StructArray ,
883+ target_type : & DataType ,
884+ ) -> Result < StructArray > {
885+ let batch = RecordBatch :: try_from_iter ( vec ! [ ( "s" , Arc :: new( input) as ArrayRef ) ] ) ?;
886+ let column = Arc :: new ( Column :: new_with_schema ( "s" , batch. schema ( ) . as_ref ( ) ) ?) ;
887+ let expr = CastExpr :: new ( column, target_type. clone ( ) , Some ( DEFAULT_CAST_OPTIONS ) ) ;
888+
889+ let result = expr. evaluate ( & batch) ?;
890+ let ColumnarValue :: Array ( array) = result else {
891+ panic ! ( "expected array" ) ;
892+ } ;
893+ let struct_array = array
894+ . as_any ( )
895+ . downcast_ref :: < StructArray > ( )
896+ . expect ( "struct array" ) ;
897+ Ok ( struct_array. clone ( ) )
898+ }
899+
/// Checks that a struct cast converts matching children and fills target
/// children that have no source counterpart with nulls.
///
/// Source: { "a": [1, null], "b": ["alpha", "beta"] }
/// Target: { "a": [1, null] (as Int64), "c": [null, null] }
#[test]
fn cast_struct_array_missing_child() -> Result<()> {
    let source_fields = Fields::from(vec![
        Field::new("a", Int32, true),
        Field::new("b", Utf8, true),
    ]);
    let a_values: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None]));
    let b_values: ArrayRef =
        Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")]));
    let source = make_struct_array(source_fields, vec![a_values, b_values]);

    let target_type = Struct(Fields::from(vec![
        Field::new("a", Int64, true),
        Field::new("c", Utf8, true),
    ]));

    let converted = cast_struct_array(source, &target_type)?;

    // "a" exists on both sides: values survive, widened to Int64.
    let a_column = converted.column_by_name("a").unwrap();
    let a_column = as_int64_array(a_column.as_ref())?;
    assert_eq!(a_column.value(0), 1);
    assert!(a_column.is_null(1));

    // "c" exists only in the target: every row must be null.
    let c_column = converted.column_by_name("c").unwrap();
    let c_column = as_string_array(c_column.as_ref())?;
    for row in 0..2 {
        assert!(c_column.is_null(row));
    }
    Ok(())
}
932+
/// Checks that casting a struct containing another struct converts the
/// nested children and adds null placeholders for nested fields that only
/// exist in the target.
///
/// Source: { "inner": { "x": [7, null] } }
/// Target: { "inner": { "x": [7, null] (as Int64), "y": [null, null] } }
#[test]
fn cast_nested_struct_array() -> Result<()> {
    let source_leaf_fields = Fields::from(vec![Field::new("x", Int32, true)]);
    let target_leaf_fields = Fields::from(vec![
        Field::new("x", Int64, true),
        Field::new("y", Boolean, true),
    ]);

    let source_inner = Field::new("inner", Struct(source_leaf_fields.clone()), true);
    let target_inner = Field::new_struct("inner", target_leaf_fields, true);
    let target_type = Struct(Fields::from(vec![target_inner]));

    let x_values: ArrayRef = Arc::new(Int32Array::from(vec![Some(7), None]));
    let inner_array: ArrayRef =
        Arc::new(make_struct_array(source_leaf_fields, vec![x_values]));
    let outer_array =
        make_struct_array(Fields::from(vec![source_inner]), vec![inner_array]);

    let converted = cast_struct_array(outer_array, &target_type)?;

    let inner = converted.column_by_name("inner").unwrap();
    let inner = inner
        .as_any()
        .downcast_ref::<StructArray>()
        .expect("inner struct");

    let x = as_int64_array(inner.column_by_name("x").unwrap().as_ref())?;
    assert_eq!(x.value(0), 7);
    assert!(x.is_null(1));

    let y = inner
        .column_by_name("y")
        .unwrap()
        .as_any()
        .downcast_ref::<BooleanArray>()
        .expect("boolean array");
    for row in 0..2 {
        assert!(y.is_null(row));
    }
    Ok(())
}
980+
/// Confirms struct casting works for scalars by casting through array form
/// and back to a `ScalarValue`.
///
/// The target adds a field `b` that is absent from the source. With source
/// and target structs that share field count and names (and have no
/// positional names), `CastExpr::evaluate` returns early through the plain
/// `cast_to` call and the scalar branch of the struct path never runs.
///
/// Input:  { "a": 9 }
/// Output: { "a": 9 (as UInt8), "b": null }
#[test]
fn cast_struct_scalar() -> Result<()> {
    let source_field = Field::new("a", Int32, true);
    let input_field = Field::new(
        "s",
        Struct(vec![Arc::new(source_field.clone())].into()),
        true,
    );
    let target_field = Field::new(
        "s",
        Struct(
            vec![
                Arc::new(Field::new("a", UInt8, true)),
                // Present only in the target: forces the name-aware struct cast.
                Arc::new(Field::new("b", Utf8, true)),
            ]
            .into(),
        ),
        true,
    );

    let schema = make_schema(&input_field);
    let scalar_struct = StructArray::new(
        vec![Arc::new(source_field.clone())].into(),
        vec![Arc::new(Int32Array::from(vec![Some(9)])) as ArrayRef],
        None,
    );
    let literal =
        Arc::new(Literal::new(ScalarValue::Struct(Arc::new(scalar_struct))));
    let expr = CastExpr::new(
        literal,
        target_field.data_type().clone(),
        Some(DEFAULT_CAST_OPTIONS),
    );

    // The literal ignores the batch contents; an empty batch only supplies a schema.
    let batch = RecordBatch::new_empty(Arc::clone(&schema));
    let result = expr.evaluate(&batch)?;
    let ColumnarValue::Scalar(ScalarValue::Struct(array)) = result else {
        panic!("expected struct scalar");
    };
    assert_eq!(array.len(), 1);

    let casted = array.column_by_name("a").unwrap();
    let casted = as_uint8_array(casted.as_ref())?;
    assert_eq!(casted.value(0), 9);

    let missing = array.column_by_name("b").unwrap();
    let missing = as_string_array(missing.as_ref())?;
    assert!(missing.is_null(0));
    Ok(())
}
8121020}
0 commit comments