1414
1515use std:: collections:: BTreeMap ;
1616use std:: collections:: HashMap ;
17+ use std:: collections:: HashSet ;
1718use std:: string:: String ;
1819use std:: sync:: Arc ;
1920
@@ -32,6 +33,7 @@ use databend_common_expression::types::StringType;
3233use databend_common_expression:: types:: ValueType ;
3334use databend_common_expression:: with_integer_mapped_type;
3435use databend_common_expression:: BlockEntry ;
36+ use databend_common_expression:: Column ;
3537use databend_common_expression:: ColumnBuilder ;
3638use databend_common_expression:: DataBlock ;
3739use databend_common_expression:: Scalar ;
@@ -53,6 +55,92 @@ use crate::sql::plans::DictGetFunctionArgument;
5355use crate :: sql:: plans:: DictionarySource ;
5456use crate :: sql:: IndexType ;
5557
/// Runs `$sql` against `$pool` and fetches at most one `(key, value)` row.
///
/// The row is decoded as `($key_type, $val_type)`. The key column is discarded and the value
/// column is passed through `$format_val_fn`. The macro evaluates to `Ok(Some(..))` when a row
/// was returned and `Ok(None)` when the query produced no rows.
///
/// The expansion contains `.await?`, so it can only be used inside an `async` body whose error
/// type the sqlx error converts into.
macro_rules! sqlx_fetch_optional {
    ($pool:expr, $sql:expr, $key_type:ty, $val_type:ty, $format_val_fn:expr) => {{
        let row = sqlx::query_as::<_, ($key_type, $val_type)>(&$sql)
            .fetch_optional($pool)
            .await?;
        Ok(row.map(|(_key, value)| $format_val_fn(value)))
    }};
}
65+
/// Fetches at most one row for a single-key lookup, dispatching on the key's data type.
///
/// * `$pool` - connection pool handed to `sqlx_fetch_optional!`.
/// * `$sql` - the complete SQL statement to run.
/// * `$key_type` - an expression evaluating to the key's `DataType`; it selects the Rust type the
///   key column is decoded as (`bool`, `String`, an integer type, `f32` or `f64`).
/// * `$val_type` - Rust type the value column is decoded as.
/// * `$format_val_fn` - converts the decoded value into the result type.
///
/// Evaluates to whatever `sqlx_fetch_optional!` yields for supported key types, and to
/// `Err(ErrorCode::DictionarySourceError(..))` for any other key type.
macro_rules! fetch_single_row_by_sqlx {
    ($pool:expr, $sql:expr, $key_type:expr, $val_type:ty, $format_val_fn:expr) => {{
        match $key_type {
            DataType::Boolean => {
                sqlx_fetch_optional!($pool, $sql, bool, $val_type, $format_val_fn)
            }
            DataType::String => {
                sqlx_fetch_optional!($pool, $sql, String, $val_type, $format_val_fn)
            }
            DataType::Number(num_ty) => with_integer_mapped_type!(|KEY_NUM_TYPE| match num_ty {
                NumberDataType::KEY_NUM_TYPE => {
                    sqlx_fetch_optional!($pool, $sql, KEY_NUM_TYPE, $val_type, $format_val_fn)
                }
                NumberDataType::Float32 => {
                    sqlx_fetch_optional!($pool, $sql, f32, $val_type, $format_val_fn)
                }
                NumberDataType::Float64 => {
                    sqlx_fetch_optional!($pool, $sql, f64, $val_type, $format_val_fn)
                }
            }),
            // Bind the unmatched type so the argument expression is evaluated only once, and
            // report it as the *key* type: the value type is validated by the caller.
            unsupported => Err(ErrorCode::DictionarySourceError(format!(
                "MySQL dictionary operator currently does not support key type {}",
                unsupported,
            ))),
        }
    }};
}
93+
/// Fetches every `(key, value)` row returned by `$sql` and collects them into a keyed collection.
///
/// * `$pool` - connection pool the query runs on.
/// * `$sql` - the complete SQL statement, passed to `sqlx::query_as` as-is.
/// * `$key_type` - an expression evaluating to the key's `DataType`; it selects the Rust type the
///   key column is decoded as.
/// * `$val_type` - Rust type the value column is decoded as.
/// * `$format_key_fn` - turns a `ScalarRef` key into the `String` used as the collection key.
///
/// Every fetched key, numeric ones included, is wrapped in a `ScalarRef` and rendered by
/// `$format_key_fn`, so the resulting keys are produced by exactly the same formatter the caller
/// uses for the keys it looks up afterwards.
///
/// The expansion contains `.await?` and, for unsupported key types, a `return Err(..)`, so it
/// must be used directly inside an `async fn` returning a compatible `Result`.
macro_rules! fetch_all_rows_by_sqlx {
    ($pool:expr, $sql:expr, $key_type:expr, $val_type:ty, $format_key_fn:expr) => {
        match $key_type {
            DataType::Boolean => {
                let rows: Vec<(bool, $val_type)> = sqlx::query_as($sql).fetch_all($pool).await?;
                rows.into_iter()
                    .map(|(k, v)| ($format_key_fn(ScalarRef::Boolean(k)), v))
                    .collect()
            }
            DataType::String => {
                let rows: Vec<(String, $val_type)> = sqlx::query_as($sql).fetch_all($pool).await?;
                rows.into_iter()
                    .map(|(k, v)| ($format_key_fn(ScalarRef::String(&k)), v))
                    .collect()
            }
            DataType::Number(num_ty) => {
                with_integer_mapped_type!(|KEY_NUM_TYPE| match num_ty {
                    NumberDataType::KEY_NUM_TYPE => {
                        let rows: Vec<(KEY_NUM_TYPE, $val_type)> =
                            sqlx::query_as($sql).fetch_all($pool).await?;
                        rows.into_iter()
                            .map(|(k, v)| {
                                let key = ScalarRef::Number(KEY_NUM_TYPE::upcast_scalar(k));
                                ($format_key_fn(key), v)
                            })
                            .collect()
                    }
                    NumberDataType::Float32 => {
                        let rows: Vec<(f32, $val_type)> =
                            sqlx::query_as($sql).fetch_all($pool).await?;
                        rows.into_iter()
                            .map(|(k, v)| {
                                let key = ScalarRef::Number(NumberScalar::Float32(k.into()));
                                ($format_key_fn(key), v)
                            })
                            .collect()
                    }
                    NumberDataType::Float64 => {
                        let rows: Vec<(f64, $val_type)> =
                            sqlx::query_as($sql).fetch_all($pool).await?;
                        rows.into_iter()
                            .map(|(k, v)| {
                                let key = ScalarRef::Number(NumberScalar::Float64(k.into()));
                                ($format_key_fn(key), v)
                            })
                            .collect()
                    }
                })
            }
            // Report the unmatched type as the *key* type: the value type is validated by the
            // caller's own match.
            unsupported => {
                return Err(ErrorCode::DictionarySourceError(format!(
                    "MySQL dictionary operator currently does not support key type: {}",
                    unsupported
                )));
            }
        }
    };
}
143+
56144pub ( crate ) enum DictionaryOperator {
57145 Redis ( ConnectionManager ) ,
58146 Mysql ( ( MySqlPool , String ) ) ,
@@ -95,21 +183,14 @@ impl DictionaryOperator {
95183 DictionaryOperator :: Mysql ( ( pool, sql) ) => match value {
96184 Value :: Scalar ( scalar) => {
97185 let value = self
98- . get_data_from_mysql ( scalar. as_ref ( ) , data_type, pool, sql)
186+ . get_scalar_value_from_mysql ( scalar. as_ref ( ) , data_type, pool, sql)
99187 . await ?
100188 . unwrap_or ( default_value. clone ( ) ) ;
101189 Ok ( Value :: Scalar ( value) )
102190 }
103191 Value :: Column ( column) => {
104- let mut builder = ColumnBuilder :: with_capacity ( data_type, column. len ( ) ) ;
105- for scalar_ref in column. iter ( ) {
106- let value = self
107- . get_data_from_mysql ( scalar_ref, data_type, pool, sql)
108- . await ?
109- . unwrap_or ( default_value. clone ( ) ) ;
110- builder. push ( value. as_ref ( ) ) ;
111- }
112- Ok ( Value :: Column ( builder. build ( ) ) )
192+ self . get_column_values_from_mysql ( column, data_type, default_value, pool, sql)
193+ . await
113194 }
114195 } ,
115196 }
@@ -239,72 +320,174 @@ impl DictionaryOperator {
239320 }
240321 }
241322
242- async fn get_data_from_mysql (
323+ async fn get_scalar_value_from_mysql (
243324 & self ,
244325 key : ScalarRef < ' _ > ,
245- data_type : & DataType ,
326+ value_type : & DataType ,
246327 pool : & MySqlPool ,
247328 sql : & String ,
248329 ) -> Result < Option < Scalar > > {
249330 if key == ScalarRef :: Null {
250331 return Ok ( None ) ;
251332 }
252- match data_type. remove_nullable ( ) {
333+ let new_sql = format ! ( "{} ({}) LIMIT 1" , sql, self . format_key( key. clone( ) ) ) ;
334+ let key_type = key. infer_data_type ( ) . remove_nullable ( ) ;
335+ match value_type. remove_nullable ( ) {
253336 DataType :: Boolean => {
254- let value: Option < bool > = sqlx:: query_scalar ( sql)
255- . bind ( self . format_key ( key) )
256- . fetch_optional ( pool)
257- . await ?;
258- Ok ( value. map ( Scalar :: Boolean ) )
337+ fetch_single_row_by_sqlx ! ( pool, new_sql, key_type, bool , Scalar :: Boolean )
259338 }
260339 DataType :: String => {
261- let value: Option < String > = sqlx:: query_scalar ( sql)
262- . bind ( self . format_key ( key) )
263- . fetch_optional ( pool)
264- . await ?;
265- Ok ( value. map ( Scalar :: String ) )
340+ fetch_single_row_by_sqlx ! ( pool, new_sql, key_type, String , Scalar :: String )
266341 }
267342 DataType :: Number ( num_ty) => {
268343 with_integer_mapped_type ! ( |NUM_TYPE | match num_ty {
269344 NumberDataType :: NUM_TYPE => {
270- let value: Option <NUM_TYPE > = sqlx:: query_scalar( & sql)
271- . bind( self . format_key( key) )
272- . fetch_optional( pool)
273- . await ?;
274- Ok ( value. map( |v| Scalar :: Number ( NUM_TYPE :: upcast_scalar( v) ) ) )
345+ fetch_single_row_by_sqlx!( pool, new_sql, key_type, NUM_TYPE , |v| {
346+ Scalar :: Number ( NUM_TYPE :: upcast_scalar( v) )
347+ } )
275348 }
276349 NumberDataType :: Float32 => {
277- let value: Option <f32 > = sqlx:: query_scalar( sql)
278- . bind( self . format_key( key) )
279- . fetch_optional( pool)
280- . await ?;
281- Ok ( value. map( |v| Scalar :: Number ( NumberScalar :: Float32 ( v. into( ) ) ) ) )
350+ fetch_single_row_by_sqlx!( pool, new_sql, key_type, f32 , |v: f32 | {
351+ Scalar :: Number ( NumberScalar :: Float32 ( v. into( ) ) )
352+ } )
282353 }
283354 NumberDataType :: Float64 => {
284- let value: Option <f64 > = sqlx:: query_scalar( sql)
285- . bind( self . format_key( key) )
286- . fetch_optional( pool)
287- . await ?;
288- Ok ( value. map( |v| Scalar :: Number ( NumberScalar :: Float64 ( v. into( ) ) ) ) )
355+ fetch_single_row_by_sqlx!( pool, new_sql, key_type, f64 , |v: f64 | {
356+ Scalar :: Number ( NumberScalar :: Float64 ( v. into( ) ) )
357+ } )
289358 }
290359 } )
291360 }
292361 _ => Err ( ErrorCode :: DictionarySourceError ( format ! (
293- "MySQL dictionary operator currently does not support value type {data_type }"
362+ "MySQL dictionary operator currently does not support value type {value_type }"
294363 ) ) ) ,
295364 }
296365 }
297366
367+ async fn get_column_values_from_mysql (
368+ & self ,
369+ column : & Column ,
370+ value_type : & DataType ,
371+ default_value : & Scalar ,
372+ pool : & MySqlPool ,
373+ sql : & String ,
374+ ) -> Result < Value < AnyType > > {
375+ // todo: The current method formats the key as a string, which causes some performance overhead.
376+ // The next step is to use the key's native types directly, such as bool, i32, etc.
377+ let key_cnt = column. len ( ) ;
378+ let mut all_keys = Vec :: with_capacity ( key_cnt) ;
379+ let mut key_set = HashSet :: with_capacity ( key_cnt) ;
380+ for item in column. iter ( ) {
381+ if item != ScalarRef :: Null {
382+ key_set. insert ( item. clone ( ) ) ;
383+ }
384+ all_keys. push ( self . format_key ( item) ) ;
385+ }
386+
387+ let mut builder = ColumnBuilder :: with_capacity ( value_type, key_cnt) ;
388+ if key_set. is_empty ( ) {
389+ for _ in 0 ..key_cnt {
390+ builder. push ( default_value. as_ref ( ) ) ;
391+ }
392+ return Ok ( Value :: Column ( builder. build ( ) ) ) ;
393+ }
394+ let new_sql = format ! ( "{} ({})" , sql, self . format_keys( key_set) ) ;
395+ let key_type = column. data_type ( ) . remove_nullable ( ) ;
396+ match value_type. remove_nullable ( ) {
397+ DataType :: Boolean => {
398+ let kv_pairs: HashMap < String , bool > =
399+ fetch_all_rows_by_sqlx ! ( pool, & new_sql, key_type, bool , |k| self . format_key( k) ) ;
400+ for key in all_keys {
401+ match kv_pairs. get ( & key) {
402+ Some ( v) => builder. push ( Scalar :: Boolean ( * v) . as_ref ( ) ) ,
403+ None => builder. push ( default_value. as_ref ( ) ) ,
404+ }
405+ }
406+ }
407+ DataType :: String => {
408+ let kv_pairs: HashMap < String , String > =
409+ fetch_all_rows_by_sqlx ! ( pool, & new_sql, key_type, String , |k| self
410+ . format_key( k) ) ;
411+ for key in all_keys {
412+ match kv_pairs. get ( & key) {
413+ Some ( v) => builder. push ( Scalar :: String ( v. to_string ( ) ) . as_ref ( ) ) ,
414+ None => builder. push ( default_value. as_ref ( ) ) ,
415+ }
416+ }
417+ }
418+ DataType :: Number ( num_ty) => {
419+ with_integer_mapped_type ! ( |NUM_TYPE | match num_ty {
420+ NumberDataType :: NUM_TYPE => {
421+ let kv_pairs: HashMap <String , NUM_TYPE > =
422+ fetch_all_rows_by_sqlx!( pool, & new_sql, key_type, NUM_TYPE , |k| self
423+ . format_key( k) ) ;
424+ for key in all_keys {
425+ match kv_pairs. get( & key) {
426+ Some ( v) => builder
427+ . push( Scalar :: Number ( NUM_TYPE :: upcast_scalar( * v) ) . as_ref( ) ) ,
428+ None => builder. push( default_value. as_ref( ) ) ,
429+ }
430+ }
431+ }
432+ NumberDataType :: Float32 => {
433+ let kv_pairs: HashMap <String , f32 > =
434+ fetch_all_rows_by_sqlx!( pool, & new_sql, key_type, f32 , |k| self
435+ . format_key( k) ) ;
436+ for key in all_keys {
437+ match kv_pairs. get( & key) {
438+ Some ( v) => builder. push(
439+ Scalar :: Number ( NumberScalar :: Float32 ( ( * v) . into( ) ) ) . as_ref( ) ,
440+ ) ,
441+ None => builder. push( default_value. as_ref( ) ) ,
442+ }
443+ }
444+ }
445+ NumberDataType :: Float64 => {
446+ let kv_pairs: HashMap <String , f64 > =
447+ fetch_all_rows_by_sqlx!( pool, & new_sql, key_type, f64 , |k| self
448+ . format_key( k) ) ;
449+ for key in all_keys {
450+ match kv_pairs. get( & key) {
451+ Some ( v) => builder. push(
452+ Scalar :: Number ( NumberScalar :: Float64 ( ( * v) . into( ) ) ) . as_ref( ) ,
453+ ) ,
454+ None => builder. push( default_value. as_ref( ) ) ,
455+ }
456+ }
457+ }
458+ } )
459+ }
460+ _ => {
461+ return Err ( ErrorCode :: DictionarySourceError ( format ! (
462+ "MySQL dictionary operator currently does not support value type {value_type}"
463+ ) ) ) ;
464+ }
465+ }
466+ Ok ( Value :: Column ( builder. build ( ) ) )
467+ }
468+
469+ #[ inline]
298470 fn format_key ( & self , key : ScalarRef < ' _ > ) -> String {
299471 match key {
300- ScalarRef :: String ( s) => s . to_string ( ) ,
472+ ScalarRef :: String ( s) => format ! ( "'{}'" , s . replace ( "'" , " \\ '" ) ) ,
301473 ScalarRef :: Date ( d) => format ! ( "{}" , date_to_string( d as i64 , & TimeZone :: UTC ) ) ,
302474 ScalarRef :: Timestamp ( t) => {
303475 format ! ( "{}" , timestamp_to_string( t, & TimeZone :: UTC ) )
304476 }
305477 _ => format ! ( "{}" , key) ,
306478 }
307479 }
480+
481+ #[ inline]
482+ fn format_keys ( & self , keys : HashSet < ScalarRef > ) -> String {
483+ format ! (
484+ "{}" ,
485+ keys. into_iter( )
486+ . map( |key| self . format_key( key) )
487+ . collect:: <Vec <String >>( )
488+ . join( "," )
489+ )
490+ }
308491}
309492
310493impl TransformAsyncFunction {
@@ -339,8 +522,11 @@ impl TransformAsyncFunction {
339522 sqlx:: MySqlPool :: connect ( & sql_source. connection_url ) ,
340523 ) ?;
341524 let sql = format ! (
342- "SELECT {} FROM {} WHERE {} = ? LIMIT 1" ,
343- & sql_source. value_field, & sql_source. table, & sql_source. key_field
525+ "SELECT {}, {} FROM {} WHERE {} in" ,
526+ & sql_source. key_field,
527+ & sql_source. value_field,
528+ & sql_source. table,
529+ & sql_source. key_field
344530 ) ;
345531 operators. insert ( i, Arc :: new ( DictionaryOperator :: Mysql ( ( mysql_pool, sql) ) ) ) ;
346532 }
0 commit comments