@@ -1004,6 +1004,11 @@ typedef struct cnonce {
10041004#define TYPE_LEN 64
10051005#define POLICY_LEN 24
10061006
1007+ struct cdb2_stmt_types {
1008+ int n ;
1009+ int types [0 ];
1010+ };
1011+
10071012struct cdb2_hndl {
10081013 char dbname [DBNAME_LEN ];
10091014 char cluster [64 ];
@@ -1102,6 +1107,7 @@ struct cdb2_hndl {
11021107 struct cdb2_hndl * fdb_hndl ;
11031108 int is_child_hndl ;
11041109 CDB2SQLQUERY__IdentityBlob * id_blob ;
1110+ struct cdb2_stmt_types * stmt_types ;
11051111};
11061112
11071113static void * cdb2_protobuf_alloc (void * allocator_data , size_t size )
@@ -3766,6 +3772,11 @@ int cdb2_close(cdb2_hndl_tp *hndl)
37663772 if (!hndl )
37673773 return 0 ;
37683774
3775+ if (hndl -> stmt_types ) {
3776+ free (hndl -> stmt_types );
3777+ hndl -> stmt_types = NULL ;
3778+ }
3779+
37693780 if (hndl -> fdb_hndl ) {
37703781 cdb2_close (hndl -> fdb_hndl );
37713782 hndl -> fdb_hndl = NULL ;
@@ -4399,6 +4410,89 @@ static inline void clear_snapshot_info(cdb2_hndl_tp *hndl, int line)
43994410 hndl -> is_retry = 0 ;
44004411}
44014412
4413+ static const struct {
4414+ const char * name ;
4415+ size_t name_sz ;
4416+ cdb2_coltype type ;
4417+ } all_types [] = {
4418+ {"INTEGER" , sizeof ("INTEGER" ) - 1 , CDB2_INTEGER },
4419+ {"CSTRING" , sizeof ("CSTRING" ) - 1 , CDB2_CSTRING },
4420+ {"REAL" , sizeof ("REAL" ) - 1 , CDB2_REAL },
4421+ {"BLOB" , sizeof ("BLOB" ) - 1 , CDB2_BLOB },
4422+ {"DATETIME" , sizeof ("DATETIME" ) - 1 , CDB2_DATETIME },
4423+ {"DATETIMEUS" , sizeof ("DATETIMEUS" ) - 1 , CDB2_DATETIMEUS },
4424+ {"INTERVALDS" , sizeof ("INTERVALDS" ) - 1 , CDB2_INTERVALDS },
4425+ {"INTERVALDSUS" , sizeof ("INTERVALDSUS" ) - 1 , CDB2_INTERVALDSUS },
4426+ {"INTERVALYM" , sizeof ("INTERVALYM" ) - 1 , CDB2_INTERVALYM }
4427+ };
4428+
4429+ static const int total_types = sizeof (all_types ) / sizeof (all_types [0 ]);
4430+
4431+ #define get_toklen (tok ) ({ \
4432+ cdb2_skipws(tok); \
4433+ int len = 0; \
4434+ while (tok[len] && !isspace(tok[len])) ++len; \
4435+ len; \
4436+ })
4437+
4438+ static int process_set_stmt_return_types (cdb2_hndl_tp * hndl , const char * sql )
4439+ {
4440+ int toklen ;
4441+ const char * tok = sql + 3 ; /* if we're here, first token is "set" */
4442+
4443+ toklen = get_toklen (tok );
4444+ if (toklen != 9 || strncasecmp (tok , "statement" , 9 ) != 0 ) return -1 ;
4445+ tok += toklen ;
4446+
4447+ toklen = get_toklen (tok );
4448+ if (toklen != 6 || strncasecmp (tok , "return" , 6 ) != 0 ) return -1 ;
4449+ tok += toklen ;
4450+
4451+ toklen = get_toklen (tok );
4452+ if (toklen != 5 || strncasecmp (tok , "types" , 5 ) != 0 ) return -1 ;
4453+ tok += toklen ;
4454+
4455+ if (hndl -> stmt_types ) {
4456+ sprintf (hndl -> errstr , "%s: already have %d parameter(s)" , __func__ , hndl -> stmt_types -> n );
4457+ return 1 ;
4458+ }
4459+
4460+ const int max_args = 1024 ;
4461+ uint8_t types [max_args ];
4462+ int count = 0 ;
4463+
4464+ while (1 ) {
4465+ toklen = get_toklen (tok );
4466+ if (toklen == 0 ) break ;
4467+ if (count == max_args ) {
4468+ sprintf (hndl -> errstr , "%s: max number of columns:%d" , __func__ , max_args );
4469+ return 1 ;
4470+ }
4471+ int i ;
4472+ for (i = 0 ; i < total_types ; ++ i ) {
4473+ if (toklen == all_types [i ].name_sz && strncasecmp (tok , all_types [i ].name , toklen ) == 0 ) {
4474+ tok += toklen ;
4475+ types [count ++ ] = all_types [i ].type ;
4476+ break ;
4477+ }
4478+ }
4479+ if (i >= total_types ) {
4480+ snprintf (hndl -> errstr , sizeof (hndl -> errstr ), "%s: column:%d has bad type:'%.*s'" , __func__ , count , toklen , tok );
4481+ return 1 ;
4482+ }
4483+ }
4484+ if (count == 0 ) {
4485+ sprintf (hndl -> errstr , "%s: bad number of columns:%d" , __func__ , count );
4486+ return 1 ;
4487+ }
4488+ hndl -> stmt_types = malloc (sizeof (struct cdb2_stmt_types ) + sizeof (int ) * count );
4489+ hndl -> stmt_types -> n = count ;
4490+ for (int i = 0 ; i < count ; ++ i ) {
4491+ hndl -> stmt_types -> types [i ] = types [i ];
4492+ }
4493+ return 0 ;
4494+ }
4495+
44024496static int process_set_command (cdb2_hndl_tp * hndl , const char * sql )
44034497{
44044498 int i , j , k ;
@@ -4411,9 +4505,9 @@ static int process_set_command(cdb2_hndl_tp *hndl, const char *sql)
44114505 return CDB2ERR_BADREQ ;
44124506 }
44134507
4414- int rc = process_ssl_set_command ( hndl , sql ) ;
4415- if (rc >= 0 )
4416- return rc ;
4508+ int rc ;
4509+ if (( rc = process_ssl_set_command ( hndl , sql )) >= 0 ) return rc ;
4510+ if (( rc = process_set_stmt_return_types ( hndl , sql )) >= 0 ) return rc ;
44174511
44184512 i = hndl -> num_set_commands ;
44194513 if (i > 0 ) {
@@ -4574,7 +4668,7 @@ static void attach_to_handle(cdb2_hndl_tp *child, cdb2_hndl_tp *parent)
45744668}
45754669
45764670static int cdb2_run_statement_typed_int (cdb2_hndl_tp * hndl , const char * sql ,
4577- int ntypes , int * types , int line )
4671+ int ntypes , int * types , int line , int * set_stmt )
45784672{
45794673 int return_value ;
45804674 int using_hint = 0 ;
@@ -4597,9 +4691,20 @@ static int cdb2_run_statement_typed_int(cdb2_hndl_tp *hndl, const char *sql,
45974691
45984692 /* sniff out 'set hasql on' here */
45994693 if (strncasecmp (sql , "set" , 3 ) == 0 ) {
4694+ * set_stmt = 1 ;
46004695 return process_set_command (hndl , sql );
46014696 }
46024697
4698+ if (hndl -> stmt_types ) {
4699+ if (ntypes || types ) {
4700+ sprintf (hndl -> errstr , "%s: provided %d type(s), but already have %d" ,
4701+ __func__ , ntypes , hndl -> stmt_types -> n );
4702+ return -1 ;
4703+ }
4704+ ntypes = hndl -> stmt_types -> n ;
4705+ types = hndl -> stmt_types -> types ;
4706+ }
4707+
46034708 if (strncasecmp (sql , "begin" , 5 ) == 0 ) {
46044709 debugprint ("setting is_begin flag\n" );
46054710 is_begin = 1 ;
@@ -5262,7 +5367,7 @@ int cdb2_run_statement_typed(cdb2_hndl_tp *hndl, const char *sql, int ntypes,
52625367{
52635368 int rc = 0 ;
52645369
5265- void * callbackrc ;
5370+ int set_stmt = 0 ;
52665371 int overwrite_rc = 0 ;
52675372 cdb2_event * e = NULL ;
52685373
@@ -5280,15 +5385,21 @@ int cdb2_run_statement_typed(cdb2_hndl_tp *hndl, const char *sql, int ntypes,
52805385
52815386 while ((e = cdb2_next_callback (hndl , CDB2_AT_ENTER_RUN_STATEMENT , e )) !=
52825387 NULL ) {
5283- callbackrc = cdb2_invoke_callback (hndl , e , 1 , CDB2_SQL , sql );
5388+ void * callbackrc = cdb2_invoke_callback (hndl , e , 1 , CDB2_SQL , sql );
52845389 PROCESS_EVENT_CTRL_BEFORE (hndl , e , rc , callbackrc , overwrite_rc );
52855390 }
52865391
5287- if (overwrite_rc )
5392+ if (overwrite_rc ) {
5393+ const char * first = sql ;
5394+ int len = get_toklen (first );
5395+ if (len == 3 && strncasecmp (first , "set" , 3 ) == 0 ) {
5396+ set_stmt = 1 ;
5397+ }
52885398 goto after_callback ;
5399+ }
52895400
52905401 if (hndl -> temp_trans && hndl -> in_trans ) {
5291- cdb2_run_statement_typed_int (hndl , "rollback" , 0 , NULL , __LINE__ );
5402+ cdb2_run_statement_typed_int (hndl , "rollback" , 0 , NULL , __LINE__ , & set_stmt );
52925403 }
52935404
52945405 hndl -> temp_trans = 0 ;
@@ -5297,7 +5408,7 @@ int cdb2_run_statement_typed(cdb2_hndl_tp *hndl, const char *sql, int ntypes,
52975408 (strncasecmp (sql , "set" , 3 ) != 0 && strncasecmp (sql , "begin" , 5 ) != 0 &&
52985409 strncasecmp (sql , "commit" , 6 ) != 0 &&
52995410 strncasecmp (sql , "rollback" , 8 ) != 0 )) {
5300- rc = cdb2_run_statement_typed_int (hndl , "begin" , 0 , NULL , __LINE__ );
5411+ rc = cdb2_run_statement_typed_int (hndl , "begin" , 0 , NULL , __LINE__ , & set_stmt );
53015412 if (rc ) {
53025413 debugprint ("cdb2_run_statement_typed_int rc = %d\n" , rc );
53035414 goto after_callback ;
@@ -5306,7 +5417,7 @@ int cdb2_run_statement_typed(cdb2_hndl_tp *hndl, const char *sql, int ntypes,
53065417 }
53075418
53085419 cdb2_skipws (sql );
5309- rc = cdb2_run_statement_typed_int (hndl , sql , ntypes , types , __LINE__ );
5420+ rc = cdb2_run_statement_typed_int (hndl , sql , ntypes , types , __LINE__ , & set_stmt );
53105421 if (rc )
53115422 debugprint ("rc = %d\n" , rc );
53125423
@@ -5315,38 +5426,51 @@ int cdb2_run_statement_typed(cdb2_hndl_tp *hndl, const char *sql, int ntypes,
53155426 if (hndl -> temp_trans && !is_sql_read (sql )) {
53165427 if (rc == 0 ) {
53175428 int commit_rc =
5318- cdb2_run_statement_typed_int (hndl , "commit" , 0 , NULL , __LINE__ );
5429+ cdb2_run_statement_typed_int (hndl , "commit" , 0 , NULL , __LINE__ , & set_stmt );
53195430 debugprint ("rc = %d\n" , commit_rc );
53205431 rc = commit_rc ;
53215432 } else {
5322- cdb2_run_statement_typed_int (hndl , "rollback" , 0 , NULL , __LINE__ );
5433+ cdb2_run_statement_typed_int (hndl , "rollback" , 0 , NULL , __LINE__ , & set_stmt );
53235434 }
53245435 hndl -> temp_trans = 0 ;
53255436 }
53265437
53275438 if (log_calls ) {
5328- if (ntypes == 0 )
5439+ if (set_stmt || ( ntypes == 0 && hndl -> stmt_types == NULL ) )
53295440 fprintf (stderr , "%p> cdb2_run_statement(%p, \"%s\") = %d\n" ,
53305441 (void * )pthread_self (), hndl , sql , rc );
5331- else {
5442+ else if ( ntypes ) {
53325443 fprintf (stderr , "%p> cdb2_run_statement_typed(%p, \"%s\", [" ,
53335444 (void * )pthread_self (), hndl , sql );
53345445 for (int i = 0 ; i < ntypes ; i ++ ) {
53355446 fprintf (stderr , "%s%s" , cdb2_type_str (types [i ]),
53365447 i == ntypes - 1 ? "" : ", " );
53375448 }
53385449 fprintf (stderr , "] = %d\n" , rc );
5450+ } else {
5451+ int n = hndl -> stmt_types -> n ;
5452+ int * t = hndl -> stmt_types -> types ;
5453+ fprintf (stderr , "%p> cdb2_run_statement_typed(%p, \"%s\", [" , (void * )pthread_self (), hndl , sql );
5454+ for (int i = 0 ; i < n ; ++ i ) {
5455+ fprintf (stderr , "%s%s" , cdb2_type_str (t [i ]), i == n - 1 ? "" : ", " );
5456+ }
5457+ fprintf (stderr , "] = %d\n" , rc );
53395458 }
53405459 }
53415460
53425461after_callback :
53435462 while ((e = cdb2_next_callback (hndl , CDB2_AT_EXIT_RUN_STATEMENT , e )) !=
53445463 NULL ) {
5345- callbackrc = cdb2_invoke_callback (hndl , e , 2 , CDB2_SQL , sql ,
5464+ void * callbackrc = cdb2_invoke_callback (hndl , e , 2 , CDB2_SQL , sql ,
53465465 CDB2_RETURN_VALUE , (intptr_t )rc );
53475466 PROCESS_EVENT_CTRL_AFTER (hndl , e , rc , callbackrc );
53485467 }
53495468
5469+ if (hndl -> stmt_types && !set_stmt ) {
5470+ free (hndl -> stmt_types );
5471+ hndl -> stmt_types = NULL ;
5472+ }
5473+
53505474 return rc ;
53515475}
53525476
0 commit comments