Skip to content

Commit f00d210

Browse files
authored
feat: add route hint support (#130)
* refactor: add TxnState * add route hint header * chore: pass rout hint in headers * reset route hint before start query * save hint * use equal fold * add comment
1 parent 8893ae1 commit f00d210

File tree

3 files changed

+49
-15
lines changed

3 files changed

+49
-15
lines changed

const.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ const (
44
DatabendTenantHeader = "X-DATABEND-TENANT"
55
DatabendWarehouseHeader = "X-DATABEND-WAREHOUSE"
66
DatabendQueryIDHeader = "X-DATABEND-QUERY-ID"
7+
DatabendRouteHintHeader = "X-DATABEND-ROUTE-HINT"
78
DatabendQueryIDNode = "X-DATABEND-NODE-ID"
89
Authorization = "Authorization"
910
WarehouseRoute = "X-DATABEND-ROUTE"

query.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ type PaginationConfig struct {
8686
MaxRowsPerPage int64 `json:"max_rows_per_page,omitempty"`
8787
}
8888

89+
type TxnState string
90+
91+
const (
92+
TxnStateActive TxnState = "Active"
93+
TxnStateAutoCommit TxnState = "AutoCommit"
94+
)
95+
8996
type SessionState struct {
9097
Database string `json:"database,omitempty"`
9198
Role string `json:"role,omitempty"`
@@ -97,7 +104,7 @@ type SessionState struct {
97104
Settings map[string]string `json:"settings,omitempty"`
98105

99106
// txn
100-
TxnState string `json:"txn_state,omitempty"`
107+
TxnState TxnState `json:"txn_state,omitempty"` // "Active", "AutoCommit"
101108
}
102109

103110
type StageAttachmentConfig struct {

restful.go

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ type APIClient struct {
9595
sessionStateRaw *json.RawMessage
9696
sessionState *SessionState
9797

98+
// routHint is used to save the route hint from the last responded X-Databend-Route-Hint, this is
99+
// used for guiding the preferred route for the next following http requests, this is useful for
100+
// some cases like query pagination & multi-statements transaction.
101+
routeHint string
102+
98103
statsTracker QueryStatsTracker
99104
accessTokenLoader AccessTokenLoader
100105

@@ -193,7 +198,7 @@ func initAccessTokenLoader(cfg *Config) AccessTokenLoader {
193198
return nil
194199
}
195200

196-
func (c *APIClient) doRequest(ctx context.Context, method, path string, req interface{}, resp interface{}) error {
201+
func (c *APIClient) doRequest(ctx context.Context, method, path string, req interface{}, resp interface{}, respHeaders *http.Header) error {
197202
if c.doRequestFunc != nil {
198203
return c.doRequestFunc(method, path, req, resp)
199204
}
@@ -267,6 +272,9 @@ func (c *APIClient) doRequest(ctx context.Context, method, path string, req inte
267272
}
268273
}
269274
}
275+
if respHeaders != nil {
276+
*respHeaders = httpResp.Header
277+
}
270278
return nil
271279
}
272280
return errors.Errorf("failed to do request after %d retries", maxRetries)
@@ -301,7 +309,6 @@ func (c *APIClient) makeHeaders(ctx context.Context) (http.Header, error) {
301309
headers.Set(UserAgent, fmt.Sprintf("%s/databend-go/%s", version, userAgent))
302310
} else {
303311
headers.Set(UserAgent, fmt.Sprintf("databend-go/%s", version))
304-
305312
}
306313
headers.Set(UserAgent, fmt.Sprintf("databend-go/%s", version))
307314
if c.tenant != "" {
@@ -310,6 +317,9 @@ func (c *APIClient) makeHeaders(ctx context.Context) (http.Header, error) {
310317
if c.warehouse != "" {
311318
headers.Set(DatabendWarehouseHeader, c.warehouse)
312319
}
320+
if c.routeHint != "" {
321+
headers.Set(DatabendRouteHintHeader, c.routeHint)
322+
}
313323

314324
if queryID, ok := ctx.Value(ContextKeyQueryID).(string); ok {
315325
headers.Set(DatabendQueryIDHeader, queryID)
@@ -367,6 +377,10 @@ func (c *APIClient) getSessionState() *SessionState {
367377
return c.sessionState
368378
}
369379

380+
func (c *APIClient) inActiveTransaction() bool {
381+
return c.sessionState != nil && strings.EqualFold(string(c.sessionState.TxnState), string(TxnStateActive))
382+
}
383+
370384
func (c *APIClient) applySessionState(response *QueryResponse) {
371385
if response == nil || response.Session == nil {
372386
return
@@ -422,7 +436,7 @@ func (c *APIClient) QuerySync(ctx context.Context, query string, args []driver.V
422436
return c.PollUntilQueryEnd(ctx, resp)
423437
}
424438

425-
func (c *APIClient) DoRetry(f retry.RetryableFunc, t RequestType) error {
439+
func (c *APIClient) doRetry(f retry.RetryableFunc, t RequestType) error {
426440
var delay time.Duration = 1
427441
var attempts uint = 3
428442
if t == Query {
@@ -458,20 +472,32 @@ func (c *APIClient) startQueryRequest(ctx context.Context, request *QueryRequest
458472
c.NextQuery()
459473
// fmt.Printf("start query %v %v\n", c.GetQueryID(), request.SQL)
460474

475+
if !c.inActiveTransaction() {
476+
c.routeHint = ""
477+
}
478+
461479
path := "/v1/query"
462-
var resp QueryResponse
463-
err := c.DoRetry(func() error {
464-
return c.doRequest(ctx, "POST", path, request, &resp)
480+
var (
481+
resp QueryResponse
482+
respHeaders http.Header
483+
)
484+
err := c.doRetry(func() error {
485+
return c.doRequest(ctx, "POST", path, request, &resp, &respHeaders)
465486
}, Query,
466487
)
467488
if err != nil {
468489
return nil, errors.Wrap(err, "failed to do query request")
469490
}
491+
492+
c.NodeID = resp.NodeID
493+
c.trackStats(&resp)
470494
// try update session as long as resp is not nil, even if query failed (resp.Error != nil)
471495
// e.g. transaction state need to be updated if commit fail
472496
c.applySessionState(&resp)
473-
c.NodeID = resp.NodeID
474-
c.trackStats(&resp)
497+
// save route hint for the next following http requests
498+
if len(respHeaders) > 0 {
499+
c.routeHint = respHeaders.Get(DatabendRouteHintHeader)
500+
}
475501
return &resp, nil
476502
}
477503

@@ -490,9 +516,9 @@ func (c *APIClient) StartQuery(ctx context.Context, query string, args []driver.
490516

491517
func (c *APIClient) PollQuery(ctx context.Context, nextURI string) (*QueryResponse, error) {
492518
var result QueryResponse
493-
err := c.DoRetry(
519+
err := c.doRetry(
494520
func() error {
495-
return c.doRequest(ctx, "GET", nextURI, nil, &result)
521+
return c.doRequest(ctx, "GET", nextURI, nil, &result, nil)
496522
},
497523
Page,
498524
)
@@ -510,8 +536,8 @@ func (c *APIClient) KillQuery(ctx context.Context, response *QueryResponse) erro
510536
if response != nil && response.KillURI != "" {
511537
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
512538
defer cancel()
513-
_ = c.DoRetry(func() error {
514-
return c.doRequest(ctx, "GET", response.KillURI, nil, nil)
539+
_ = c.doRetry(func() error {
540+
return c.doRequest(ctx, "GET", response.KillURI, nil, nil, nil)
515541
}, Kill,
516542
)
517543
}
@@ -522,8 +548,8 @@ func (c *APIClient) CloseQuery(ctx context.Context, response *QueryResponse) err
522548
if response != nil && response.FinalURI != "" {
523549
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
524550
defer cancel()
525-
_ = c.DoRetry(func() error {
526-
return c.doRequest(ctx, "GET", response.FinalURI, nil, nil)
551+
_ = c.doRetry(func() error {
552+
return c.doRequest(ctx, "GET", response.FinalURI, nil, nil, nil)
527553
}, Final,
528554
)
529555
}

0 commit comments

Comments
 (0)