@@ -3608,6 +3608,11 @@ int cdb2_close(cdb2_hndl_tp *hndl)
36083608 if (!hndl )
36093609 return 0 ;
36103610
3611+ if (hndl -> stmt_types ) {
3612+ free (hndl -> stmt_types );
3613+ hndl -> stmt_types = NULL ;
3614+ }
3615+
36113616 if (hndl -> fdb_hndl ) {
36123617 cdb2_close (hndl -> fdb_hndl );
36133618 hndl -> fdb_hndl = NULL ;
@@ -4241,6 +4246,94 @@ static inline void clear_snapshot_info(cdb2_hndl_tp *hndl, int line)
42414246 hndl -> is_retry = 0 ;
42424247}
42434248
4249+ static const struct {
4250+ const char * name ;
4251+ size_t name_sz ;
4252+ cdb2_coltype type ;
4253+ } all_types [] = {{"INTEGER" , sizeof ("INTEGER" ) - 1 , CDB2_INTEGER },
4254+ {"CSTRING" , sizeof ("CSTRING" ) - 1 , CDB2_CSTRING },
4255+ {"REAL" , sizeof ("REAL" ) - 1 , CDB2_REAL },
4256+ {"BLOB" , sizeof ("BLOB" ) - 1 , CDB2_BLOB },
4257+ {"DATETIME" , sizeof ("DATETIME" ) - 1 , CDB2_DATETIME },
4258+ {"DATETIMEUS" , sizeof ("DATETIMEUS" ) - 1 , CDB2_DATETIMEUS },
4259+ {"INTERVALDS" , sizeof ("INTERVALDS" ) - 1 , CDB2_INTERVALDS },
4260+ {"INTERVALDSUS" , sizeof ("INTERVALDSUS" ) - 1 , CDB2_INTERVALDSUS },
4261+ {"INTERVALYM" , sizeof ("INTERVALYM" ) - 1 , CDB2_INTERVALYM }};
4262+
4263+ static const int total_types = sizeof (all_types ) / sizeof (all_types [0 ]);
4264+
4265+ #define get_toklen (tok ) \
4266+ ({ \
4267+ cdb2_skipws(tok); \
4268+ int len = 0; \
4269+ while (tok[len] && !isspace(tok[len])) \
4270+ ++len; \
4271+ len; \
4272+ })
4273+
4274+ static int process_set_stmt_return_types (cdb2_hndl_tp * hndl , const char * sql )
4275+ {
4276+ int toklen ;
4277+ const char * tok = sql + 3 ; /* if we're here, first token is "set" */
4278+
4279+ toklen = get_toklen (tok );
4280+ if (toklen != 9 || strncasecmp (tok , "statement" , 9 ) != 0 )
4281+ return -1 ;
4282+ tok += toklen ;
4283+
4284+ toklen = get_toklen (tok );
4285+ if (toklen != 6 || strncasecmp (tok , "return" , 6 ) != 0 )
4286+ return -1 ;
4287+ tok += toklen ;
4288+
4289+ toklen = get_toklen (tok );
4290+ if (toklen != 5 || strncasecmp (tok , "types" , 5 ) != 0 )
4291+ return -1 ;
4292+ tok += toklen ;
4293+
4294+ if (hndl -> stmt_types ) {
4295+ sprintf (hndl -> errstr , "%s: already have %d parameter(s)" , __func__ , hndl -> stmt_types -> n );
4296+ return 1 ;
4297+ }
4298+
4299+ const int max_args = 1024 ;
4300+ uint8_t types [max_args ];
4301+ int count = 0 ;
4302+
4303+ while (1 ) {
4304+ toklen = get_toklen (tok );
4305+ if (toklen == 0 )
4306+ break ;
4307+ if (count == max_args ) {
4308+ sprintf (hndl -> errstr , "%s: max number of columns:%d" , __func__ , max_args );
4309+ return 1 ;
4310+ }
4311+ int i ;
4312+ for (i = 0 ; i < total_types ; ++ i ) {
4313+ if (toklen == all_types [i ].name_sz && strncasecmp (tok , all_types [i ].name , toklen ) == 0 ) {
4314+ tok += toklen ;
4315+ types [count ++ ] = all_types [i ].type ;
4316+ break ;
4317+ }
4318+ }
4319+ if (i >= total_types ) {
4320+ snprintf (hndl -> errstr , sizeof (hndl -> errstr ), "%s: column:%d has bad type:'%.*s'" , __func__ , count , toklen ,
4321+ tok );
4322+ return 1 ;
4323+ }
4324+ }
4325+ if (count == 0 ) {
4326+ sprintf (hndl -> errstr , "%s: bad number of columns:%d" , __func__ , count );
4327+ return 1 ;
4328+ }
4329+ hndl -> stmt_types = malloc (sizeof (struct cdb2_stmt_types ) + sizeof (int ) * count );
4330+ hndl -> stmt_types -> n = count ;
4331+ for (int i = 0 ; i < count ; ++ i ) {
4332+ hndl -> stmt_types -> types [i ] = types [i ];
4333+ }
4334+ return 0 ;
4335+ }
4336+
42444337static int process_set_command (cdb2_hndl_tp * hndl , const char * sql )
42454338{
42464339 int i , j , k ;
@@ -4253,8 +4346,10 @@ static int process_set_command(cdb2_hndl_tp *hndl, const char *sql)
42534346 return CDB2ERR_BADREQ ;
42544347 }
42554348
4256- int rc = process_ssl_set_command (hndl , sql );
4257- if (rc >= 0 )
4349+ int rc ;
4350+ if ((rc = process_ssl_set_command (hndl , sql )) >= 0 )
4351+ return rc ;
4352+ if ((rc = process_set_stmt_return_types (hndl , sql )) >= 0 )
42584353 return rc ;
42594354
42604355 i = hndl -> num_set_commands ;
@@ -4415,8 +4510,8 @@ static void attach_to_handle(cdb2_hndl_tp *child, cdb2_hndl_tp *parent)
44154510 child -> context_msgs .has_changed = child -> context_msgs .count > 0 ;
44164511}
44174512
4418- static int cdb2_run_statement_typed_int (cdb2_hndl_tp * hndl , const char * sql ,
4419- int ntypes , int * types , int line )
4513+ static int cdb2_run_statement_typed_int (cdb2_hndl_tp * hndl , const char * sql , int ntypes , int * types , int line ,
4514+ int * set_stmt )
44204515{
44214516 int return_value ;
44224517 int using_hint = 0 ;
@@ -4439,9 +4534,20 @@ static int cdb2_run_statement_typed_int(cdb2_hndl_tp *hndl, const char *sql,
44394534
44404535 /* sniff out 'set hasql on' here */
44414536 if (strncasecmp (sql , "set" , 3 ) == 0 ) {
4537+ * set_stmt = 1 ;
44424538 return process_set_command (hndl , sql );
44434539 }
44444540
4541+ if (hndl -> stmt_types ) {
4542+ if (ntypes || types ) {
4543+ sprintf (hndl -> errstr , "%s: provided %d type(s), but already have %d" , __func__ , ntypes ,
4544+ hndl -> stmt_types -> n );
4545+ return -1 ;
4546+ }
4547+ ntypes = hndl -> stmt_types -> n ;
4548+ types = hndl -> stmt_types -> types ;
4549+ }
4550+
44454551 if (strncasecmp (sql , "begin" , 5 ) == 0 ) {
44464552 debugprint ("setting is_begin flag\n" );
44474553 is_begin = 1 ;
@@ -5103,7 +5209,7 @@ int cdb2_run_statement_typed(cdb2_hndl_tp *hndl, const char *sql, int ntypes,
51035209{
51045210 int rc = 0 ;
51055211
5106- void * callbackrc ;
5212+ int set_stmt = 0 ;
51075213 int overwrite_rc = 0 ;
51085214 cdb2_event * e = NULL ;
51095215
@@ -5121,15 +5227,21 @@ int cdb2_run_statement_typed(cdb2_hndl_tp *hndl, const char *sql, int ntypes,
51215227
51225228 while ((e = cdb2_next_callback (hndl , CDB2_AT_ENTER_RUN_STATEMENT , e )) !=
51235229 NULL ) {
5124- callbackrc = cdb2_invoke_callback (hndl , e , 1 , CDB2_SQL , sql );
5230+ void * callbackrc = cdb2_invoke_callback (hndl , e , 1 , CDB2_SQL , sql );
51255231 PROCESS_EVENT_CTRL_BEFORE (hndl , e , rc , callbackrc , overwrite_rc );
51265232 }
51275233
5128- if (overwrite_rc )
5234+ if (overwrite_rc ) {
5235+ const char * first = sql ;
5236+ int len = get_toklen (first );
5237+ if (len == 3 && strncasecmp (first , "set" , 3 ) == 0 ) {
5238+ set_stmt = 1 ;
5239+ }
51295240 goto after_callback ;
5241+ }
51305242
51315243 if (hndl -> temp_trans && hndl -> in_trans ) {
5132- cdb2_run_statement_typed_int (hndl , "rollback" , 0 , NULL , __LINE__ );
5244+ cdb2_run_statement_typed_int (hndl , "rollback" , 0 , NULL , __LINE__ , & set_stmt );
51335245 }
51345246
51355247 hndl -> temp_trans = 0 ;
@@ -5138,7 +5250,7 @@ int cdb2_run_statement_typed(cdb2_hndl_tp *hndl, const char *sql, int ntypes,
51385250 (strncasecmp (sql , "set" , 3 ) != 0 && strncasecmp (sql , "begin" , 5 ) != 0 &&
51395251 strncasecmp (sql , "commit" , 6 ) != 0 &&
51405252 strncasecmp (sql , "rollback" , 8 ) != 0 )) {
5141- rc = cdb2_run_statement_typed_int (hndl , "begin" , 0 , NULL , __LINE__ );
5253+ rc = cdb2_run_statement_typed_int (hndl , "begin" , 0 , NULL , __LINE__ , & set_stmt );
51425254 if (rc ) {
51435255 debugprint ("cdb2_run_statement_typed_int rc = %d\n" , rc );
51445256 goto after_callback ;
@@ -5147,47 +5259,58 @@ int cdb2_run_statement_typed(cdb2_hndl_tp *hndl, const char *sql, int ntypes,
51475259 }
51485260
51495261 cdb2_skipws (sql );
5150- rc = cdb2_run_statement_typed_int (hndl , sql , ntypes , types , __LINE__ );
5262+ rc = cdb2_run_statement_typed_int (hndl , sql , ntypes , types , __LINE__ , & set_stmt );
51515263 if (rc )
51525264 debugprint ("rc = %d\n" , rc );
51535265
51545266 // XXX This code does not work correctly for WITH statements
51555267 // (they can be either read or write)
51565268 if (hndl -> temp_trans && !is_sql_read (sql )) {
51575269 if (rc == 0 ) {
5158- int commit_rc =
5159- cdb2_run_statement_typed_int (hndl , "commit" , 0 , NULL , __LINE__ );
5270+ int commit_rc = cdb2_run_statement_typed_int (hndl , "commit" , 0 , NULL , __LINE__ , & set_stmt );
51605271 debugprint ("rc = %d\n" , commit_rc );
51615272 rc = commit_rc ;
51625273 } else {
5163- cdb2_run_statement_typed_int (hndl , "rollback" , 0 , NULL , __LINE__ );
5274+ cdb2_run_statement_typed_int (hndl , "rollback" , 0 , NULL , __LINE__ , & set_stmt );
51645275 }
51655276 hndl -> temp_trans = 0 ;
51665277 }
51675278
51685279 if (log_calls ) {
5169- if (ntypes == 0 )
5280+ if (set_stmt || ( ntypes == 0 && hndl -> stmt_types == NULL ) )
51705281 fprintf (stderr , "%p> cdb2_run_statement(%p, \"%s\") = %d\n" ,
51715282 (void * )pthread_self (), hndl , sql , rc );
5172- else {
5283+ else if ( ntypes ) {
51735284 fprintf (stderr , "%p> cdb2_run_statement_typed(%p, \"%s\", [" ,
51745285 (void * )pthread_self (), hndl , sql );
51755286 for (int i = 0 ; i < ntypes ; i ++ ) {
51765287 fprintf (stderr , "%s%s" , cdb2_type_str (types [i ]),
51775288 i == ntypes - 1 ? "" : ", " );
51785289 }
51795290 fprintf (stderr , "] = %d\n" , rc );
5291+ } else {
5292+ int n = hndl -> stmt_types -> n ;
5293+ int * t = hndl -> stmt_types -> types ;
5294+ fprintf (stderr , "%p> cdb2_run_statement_typed(%p, \"%s\", [" , (void * )pthread_self (), hndl , sql );
5295+ for (int i = 0 ; i < n ; ++ i ) {
5296+ fprintf (stderr , "%s%s" , cdb2_type_str (t [i ]), i == n - 1 ? "" : ", " );
5297+ }
5298+ fprintf (stderr , "] = %d\n" , rc );
51805299 }
51815300 }
51825301
51835302after_callback :
51845303 while ((e = cdb2_next_callback (hndl , CDB2_AT_EXIT_RUN_STATEMENT , e )) !=
51855304 NULL ) {
5186- callbackrc = cdb2_invoke_callback (hndl , e , 2 , CDB2_SQL , sql ,
5187- CDB2_RETURN_VALUE , (intptr_t )rc );
5305+ void * callbackrc = cdb2_invoke_callback (hndl , e , 2 , CDB2_SQL , sql , CDB2_RETURN_VALUE , (intptr_t )rc );
51885306 PROCESS_EVENT_CTRL_AFTER (hndl , e , rc , callbackrc );
51895307 }
51905308
5309+ if (hndl -> stmt_types && !set_stmt ) {
5310+ free (hndl -> stmt_types );
5311+ hndl -> stmt_types = NULL ;
5312+ }
5313+
51915314 return rc ;
51925315}
51935316
0 commit comments