@@ -109,7 +109,12 @@ impl CaseExpr {
109109
110110macro_rules! if_then_else {
111111 ( $BUILDER_TYPE: ty, $ARRAY_TYPE: ty, $BOOLS: expr, $TRUE: expr, $FALSE: expr) => { {
112- let true_values = $TRUE
112+ let true_values = if $TRUE. data_type( ) == & DataType :: Null {
113+ Arc :: new( <$ARRAY_TYPE>:: from( vec![ None ; $TRUE. len( ) ] ) )
114+ } else {
115+ $TRUE
116+ } ;
117+ let true_values = true_values
113118 . as_ref( )
114119 . as_any( )
115120 . downcast_ref:: <$ARRAY_TYPE>( )
@@ -118,7 +123,12 @@ macro_rules! if_then_else {
118123 stringify!( $ARRAY_TYPE)
119124 ) ) ;
120125
121- let false_values = $FALSE
126+ let false_values = if $FALSE. data_type( ) == & DataType :: Null {
127+ Arc :: new( <$ARRAY_TYPE>:: from( vec![ None ; $FALSE. len( ) ] ) )
128+ } else {
129+ $FALSE
130+ } ;
131+ let false_values = false_values
122132 . as_ref( )
123133 . as_any( )
124134 . downcast_ref:: <$ARRAY_TYPE>( )
@@ -252,6 +262,23 @@ fn if_then_else(
252262}
253263
254264impl CaseExpr {
265+ /// This function returns the return type of CASE expression.
266+ ///
267+ /// The first non-Null THEN expr type is returned; if there are none, ELSE type is returned.
268+ /// In the abscense of ELSE, Null is returned.
269+ fn return_type ( & self , schema : & Schema ) -> Result < DataType > {
270+ for ( _, then) in self . when_then_expr . iter ( ) {
271+ match then. data_type ( schema) ? {
272+ DataType :: Null => continue ,
273+ dt => return Ok ( dt) ,
274+ } ;
275+ }
276+ if let Some ( else_expr) = & self . else_expr {
277+ return else_expr. data_type ( schema) ;
278+ }
279+ Ok ( DataType :: Null )
280+ }
281+
255282 /// This function evaluates the form of CASE that matches an expression to fixed values.
256283 ///
257284 /// CASE expression
@@ -260,7 +287,7 @@ impl CaseExpr {
260287 /// [ELSE result]
261288 /// END
262289 fn case_when_with_expr ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
263- let return_type = self . when_then_expr [ 0 ] . 1 . data_type ( & batch. schema ( ) ) ?;
290+ let return_type = self . return_type ( & batch. schema ( ) ) ?;
264291 let expr = self . expr . as_ref ( ) . unwrap ( ) ;
265292 let base_value = expr. evaluate ( batch) ?;
266293 let base_value = base_value. into_array ( batch. num_rows ( ) ) ;
@@ -339,7 +366,7 @@ impl CaseExpr {
339366 /// [ELSE result]
340367 /// END
341368 fn case_when_no_expr ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
342- let return_type = self . when_then_expr [ 0 ] . 1 . data_type ( & batch. schema ( ) ) ?;
369+ let return_type = self . return_type ( & batch. schema ( ) ) ?;
343370
344371 // start with nulls as default output
345372 let mut current_value = new_null_array ( & return_type, batch. num_rows ( ) ) ;
@@ -392,7 +419,7 @@ impl PhysicalExpr for CaseExpr {
392419 }
393420
394421 fn data_type ( & self , input_schema : & Schema ) -> Result < DataType > {
395- self . when_then_expr [ 0 ] . 1 . data_type ( input_schema)
422+ self . return_type ( input_schema)
396423 }
397424
398425 fn nullable ( & self , input_schema : & Schema ) -> Result < bool > {
0 commit comments