@@ -76,6 +76,9 @@ use crate::arrow::datatypes::TimeUnit;
7676use crate :: execution:: context:: TaskContext ;
7777use crate :: physical_plan:: coalesce_batches:: concat_batches;
7878use crate :: physical_plan:: PhysicalExpr ;
79+ use datafusion_expr:: binary_rule:: coerce_types;
80+ use datafusion_expr:: Operator ;
81+ use datafusion_physical_expr:: expressions:: try_cast;
7982use log:: debug;
8083use std:: fmt;
8184
@@ -295,7 +298,32 @@ impl ExecutionPlan for HashJoinExec {
295298 partition : usize ,
296299 context : Arc < TaskContext > ,
297300 ) -> Result < SendableRecordBatchStream > {
298- let on_left = self . on . iter ( ) . map ( |on| on. 0 . clone ( ) ) . collect :: < Vec < _ > > ( ) ;
301+ // This is a hacky way to support type coercion for join expressions.
302+ // Without it, the join would panic later, in build_join_indexes => equal_rows, when that code tries to downcast both sides to the same primitive type.
303+ // TODO: Remove this after rebasing on top of commit ac2e5d15 "Support type coercion for equijoin (#4666)", which was first released in DataFusion 16.0.
304+
305+ // TODO: Rewrite this with iterators once the toolchain is upgraded; `impl FromIterator<(AE, BE)> for (A, B)` is not available at the moment.
306+ let mut on_left = Vec :: with_capacity ( self . on . len ( ) ) ;
307+ let mut on_right = Vec :: with_capacity ( self . on . len ( ) ) ;
308+ for on in & self . on {
309+ let l = Arc :: new ( on. 0 . clone ( ) ) ;
310+ let r = Arc :: new ( on. 1 . clone ( ) ) ;
311+
312+ let lt = l. data_type ( & self . left . schema ( ) ) ?;
313+ let rt = r. data_type ( & self . right . schema ( ) ) ?;
314+ let res_type = coerce_types ( & lt, & Operator :: Eq , & rt) ?;
315+
316+ let left_cast = try_cast ( l, & self . left . schema ( ) , res_type. clone ( ) ) ?;
317+ let right_cast = try_cast ( r, & self . right . schema ( ) , res_type) ?;
318+
319+ on_left. push ( left_cast) ;
320+ on_right. push ( right_cast) ;
321+ }
322+
323+ // Make them immutable
324+ let on_left = on_left;
325+ let on_right = on_right;
326+
299327 // we only want to compute the build side once for PartitionMode::CollectLeft
300328 let left_data = {
301329 match self . mode {
@@ -414,7 +442,6 @@ impl ExecutionPlan for HashJoinExec {
414442 // over the right that uses this information to issue new batches.
415443
416444 let right_stream = self . right . execute ( partition, context. clone ( ) ) . await ?;
417- let on_right = self . on . iter ( ) . map ( |on| on. 1 . clone ( ) ) . collect :: < Vec < _ > > ( ) ;
418445
419446 let num_rows = left_data. 1 . num_rows ( ) ;
420447 let visited_left_side = match self . join_type {
@@ -473,7 +500,7 @@ impl ExecutionPlan for HashJoinExec {
473500/// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`,
474501/// assuming that the [RecordBatch] corresponds to the `index`th
475502fn update_hash (
476- on : & [ Column ] ,
503+ on : & [ Arc < dyn PhysicalExpr > ] ,
477504 batch : & RecordBatch ,
478505 hash_map : & mut JoinHashMap ,
479506 offset : usize ,
@@ -512,9 +539,9 @@ struct HashJoinStream {
512539 /// Input schema
513540 schema : Arc < Schema > ,
514541 /// join key expressions from the left
515- on_left : Vec < Column > ,
542+ on_left : Vec < Arc < dyn PhysicalExpr > > ,
516543 /// join key expressions from the right used to compute the hash
517- on_right : Vec < Column > ,
544+ on_right : Vec < Arc < dyn PhysicalExpr > > ,
518545 /// type of the join
519546 join_type : JoinType ,
520547 /// information from the left
@@ -539,8 +566,8 @@ struct HashJoinStream {
539566impl HashJoinStream {
540567 fn new (
541568 schema : Arc < Schema > ,
542- on_left : Vec < Column > ,
543- on_right : Vec < Column > ,
569+ on_left : Vec < Arc < dyn PhysicalExpr > > ,
570+ on_right : Vec < Arc < dyn PhysicalExpr > > ,
544571 join_type : JoinType ,
545572 left_data : JoinLeftData ,
546573 right : SendableRecordBatchStream ,
@@ -624,8 +651,8 @@ fn build_batch_from_indices(
624651fn build_batch (
625652 batch : & RecordBatch ,
626653 left_data : & JoinLeftData ,
627- on_left : & [ Column ] ,
628- on_right : & [ Column ] ,
654+ on_left : & [ Arc < dyn PhysicalExpr > ] ,
655+ on_right : & [ Arc < dyn PhysicalExpr > ] ,
629656 join_type : JoinType ,
630657 schema : & Schema ,
631658 column_indices : & [ ColumnIndex ] ,
@@ -691,8 +718,8 @@ fn build_join_indexes(
691718 left_data : & JoinLeftData ,
692719 right : & RecordBatch ,
693720 join_type : JoinType ,
694- left_on : & [ Column ] ,
695- right_on : & [ Column ] ,
721+ left_on : & [ Arc < dyn PhysicalExpr > ] ,
722+ right_on : & [ Arc < dyn PhysicalExpr > ] ,
696723 random_state : & RandomState ,
697724 null_equals_null : & bool ,
698725) -> Result < ( UInt64Array , UInt32Array ) > {
@@ -2002,8 +2029,8 @@ mod tests {
20022029 & left_data,
20032030 & right,
20042031 JoinType :: Inner ,
2005- & [ Column :: new ( "a" , 0 ) ] ,
2006- & [ Column :: new ( "a" , 0 ) ] ,
2032+ & [ Arc :: new ( Column :: new ( "a" , 0 ) ) ] ,
2033+ & [ Arc :: new ( Column :: new ( "a" , 0 ) ) ] ,
20072034 & random_state,
20082035 & false ,
20092036 ) ?;
0 commit comments