@@ -95,6 +95,11 @@ type APIClient struct {
9595 sessionStateRaw * json.RawMessage
9696 sessionState * SessionState
9797
98+ // routHint is used to save the route hint from the last responded X-Databend-Route-Hint, this is
99+ // used for guiding the preferred route for the next following http requests, this is useful for
100+ // some cases like query pagination & multi-statements transaction.
101+ routeHint string
102+
98103 statsTracker QueryStatsTracker
99104 accessTokenLoader AccessTokenLoader
100105
@@ -193,7 +198,7 @@ func initAccessTokenLoader(cfg *Config) AccessTokenLoader {
193198 return nil
194199}
195200
196- func (c * APIClient ) doRequest (ctx context.Context , method , path string , req interface {}, resp interface {}) error {
201+ func (c * APIClient ) doRequest (ctx context.Context , method , path string , req interface {}, resp interface {}, respHeaders * http. Header ) error {
197202 if c .doRequestFunc != nil {
198203 return c .doRequestFunc (method , path , req , resp )
199204 }
@@ -267,6 +272,9 @@ func (c *APIClient) doRequest(ctx context.Context, method, path string, req inte
267272 }
268273 }
269274 }
275+ if respHeaders != nil {
276+ * respHeaders = httpResp .Header
277+ }
270278 return nil
271279 }
272280 return errors .Errorf ("failed to do request after %d retries" , maxRetries )
@@ -301,7 +309,6 @@ func (c *APIClient) makeHeaders(ctx context.Context) (http.Header, error) {
301309 headers .Set (UserAgent , fmt .Sprintf ("%s/databend-go/%s" , version , userAgent ))
302310 } else {
303311 headers .Set (UserAgent , fmt .Sprintf ("databend-go/%s" , version ))
304-
305312 }
306313 headers .Set (UserAgent , fmt .Sprintf ("databend-go/%s" , version ))
307314 if c .tenant != "" {
@@ -310,6 +317,9 @@ func (c *APIClient) makeHeaders(ctx context.Context) (http.Header, error) {
310317 if c .warehouse != "" {
311318 headers .Set (DatabendWarehouseHeader , c .warehouse )
312319 }
320+ if c .routeHint != "" {
321+ headers .Set (DatabendRouteHintHeader , c .routeHint )
322+ }
313323
314324 if queryID , ok := ctx .Value (ContextKeyQueryID ).(string ); ok {
315325 headers .Set (DatabendQueryIDHeader , queryID )
@@ -367,6 +377,10 @@ func (c *APIClient) getSessionState() *SessionState {
367377 return c .sessionState
368378}
369379
380+ func (c * APIClient ) inActiveTransaction () bool {
381+ return c .sessionState != nil && strings .EqualFold (string (c .sessionState .TxnState ), string (TxnStateActive ))
382+ }
383+
370384func (c * APIClient ) applySessionState (response * QueryResponse ) {
371385 if response == nil || response .Session == nil {
372386 return
@@ -422,7 +436,7 @@ func (c *APIClient) QuerySync(ctx context.Context, query string, args []driver.V
422436 return c .PollUntilQueryEnd (ctx , resp )
423437}
424438
425- func (c * APIClient ) DoRetry (f retry.RetryableFunc , t RequestType ) error {
439+ func (c * APIClient ) doRetry (f retry.RetryableFunc , t RequestType ) error {
426440 var delay time.Duration = 1
427441 var attempts uint = 3
428442 if t == Query {
@@ -458,20 +472,32 @@ func (c *APIClient) startQueryRequest(ctx context.Context, request *QueryRequest
458472 c .NextQuery ()
459473 // fmt.Printf("start query %v %v\n", c.GetQueryID(), request.SQL)
460474
475+ if ! c .inActiveTransaction () {
476+ c .routeHint = ""
477+ }
478+
461479 path := "/v1/query"
462- var resp QueryResponse
463- err := c .DoRetry (func () error {
464- return c .doRequest (ctx , "POST" , path , request , & resp )
480+ var (
481+ resp QueryResponse
482+ respHeaders http.Header
483+ )
484+ err := c .doRetry (func () error {
485+ return c .doRequest (ctx , "POST" , path , request , & resp , & respHeaders )
465486 }, Query ,
466487 )
467488 if err != nil {
468489 return nil , errors .Wrap (err , "failed to do query request" )
469490 }
491+
492+ c .NodeID = resp .NodeID
493+ c .trackStats (& resp )
470494 // try update session as long as resp is not nil, even if query failed (resp.Error != nil)
471495 // e.g. transaction state need to be updated if commit fail
472496 c .applySessionState (& resp )
473- c .NodeID = resp .NodeID
474- c .trackStats (& resp )
497+ // save route hint for the next following http requests
498+ if len (respHeaders ) > 0 {
499+ c .routeHint = respHeaders .Get (DatabendRouteHintHeader )
500+ }
475501 return & resp , nil
476502}
477503
@@ -490,9 +516,9 @@ func (c *APIClient) StartQuery(ctx context.Context, query string, args []driver.
490516
491517func (c * APIClient ) PollQuery (ctx context.Context , nextURI string ) (* QueryResponse , error ) {
492518 var result QueryResponse
493- err := c .DoRetry (
519+ err := c .doRetry (
494520 func () error {
495- return c .doRequest (ctx , "GET" , nextURI , nil , & result )
521+ return c .doRequest (ctx , "GET" , nextURI , nil , & result , nil )
496522 },
497523 Page ,
498524 )
@@ -510,8 +536,8 @@ func (c *APIClient) KillQuery(ctx context.Context, response *QueryResponse) erro
510536 if response != nil && response .KillURI != "" {
511537 ctx , cancel := context .WithTimeout (ctx , 30 * time .Second )
512538 defer cancel ()
513- _ = c .DoRetry (func () error {
514- return c .doRequest (ctx , "GET" , response .KillURI , nil , nil )
539+ _ = c .doRetry (func () error {
540+ return c .doRequest (ctx , "GET" , response .KillURI , nil , nil , nil )
515541 }, Kill ,
516542 )
517543 }
@@ -522,8 +548,8 @@ func (c *APIClient) CloseQuery(ctx context.Context, response *QueryResponse) err
522548 if response != nil && response .FinalURI != "" {
523549 ctx , cancel := context .WithTimeout (ctx , 30 * time .Second )
524550 defer cancel ()
525- _ = c .DoRetry (func () error {
526- return c .doRequest (ctx , "GET" , response .FinalURI , nil , nil )
551+ _ = c .doRetry (func () error {
552+ return c .doRequest (ctx , "GET" , response .FinalURI , nil , nil , nil )
527553 }, Final ,
528554 )
529555 }
0 commit comments