Skip to content

Commit 47ac2ed

Browse files
authored
feat: support query forward and temp table (#143)
* support cookie * ci: TestTempTable() * feat: support query forward. * feat: support logout * test logout * fix * ci: use image datafuselabs/databend:nightly
1 parent a06d3bc commit 47ac2ed

File tree

7 files changed

+106
-9
lines changed

7 files changed

+106
-9
lines changed

client.go

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,26 @@ func (c *APIClient) GetQueryID() string {
127127
return fmt.Sprintf("%s.%d", c.SessionID, c.QuerySeq)
128128
}
129129

130+
func (c *APIClient) NeedSticky() bool {
131+
if c.sessionState != nil {
132+
return c.sessionState.NeedSticky
133+
}
134+
return false
135+
}
136+
137+
func (c *APIClient) NeedKeepAlive() bool {
138+
if c.sessionState != nil {
139+
return c.sessionState.NeedKeepAlive
140+
}
141+
return false
142+
}
143+
130144
func NewAPIHttpClientFromConfig(cfg *Config) *http.Client {
145+
jar := NewIgnoreDomainCookieJar()
146+
jar.SetCookies(nil, []*http.Cookie{{Name: "cookie_enabled", Value: "true"}})
131147
cli := &http.Client{
132148
Timeout: cfg.Timeout,
149+
Jar: jar,
133150
}
134151
if cfg.EnableOpenTelemetry {
135152
cli.Transport = otelhttp.NewTransport(http.DefaultTransport)
@@ -148,7 +165,7 @@ func NewAPIClientFromConfig(cfg *Config) *APIClient {
148165

149166
// if role is set in config, we'd prefer to limit it as the only effective role,
150167
// so you could limit the privileges by setting a role with limited privileges.
151-
// however this can be overridden by executing `SET SECONDARY ROLES ALL` in the
168+
// however, this can be overridden by executing `SET SECONDARY ROLES ALL` in the
152169
// query.
153170
// secondaryRoles now have two viable values:
154171
// - nil: means enabling ALL the granted roles of the user
@@ -202,7 +219,7 @@ func initAccessTokenLoader(cfg *Config) AccessTokenLoader {
202219
return nil
203220
}
204221

205-
func (c *APIClient) doRequest(ctx context.Context, method, path string, req interface{}, resp interface{}, respHeaders *http.Header) error {
222+
func (c *APIClient) doRequest(ctx context.Context, method, path string, req interface{}, needSticky bool, resp interface{}, respHeaders *http.Header) error {
206223
if c.doRequestFunc != nil {
207224
return c.doRequestFunc(method, path, req, resp)
208225
}
@@ -226,6 +243,9 @@ func (c *APIClient) doRequest(ctx context.Context, method, path string, req inte
226243
maxRetries := 2
227244
for i := 1; i <= maxRetries; i++ {
228245
headers, err := c.makeHeaders(ctx)
246+
if needSticky {
247+
headers.Set(DatabendQueryStickyNode, c.NodeID)
248+
}
229249
if err != nil {
230250
return errors.Wrap(err, "failed to make request headers")
231251
}
@@ -484,7 +504,7 @@ func (c *APIClient) startQueryRequest(ctx context.Context, request *QueryRequest
484504
respHeaders http.Header
485505
)
486506
err := c.doRetry(func() error {
487-
return c.doRequest(ctx, "POST", path, request, &resp, &respHeaders)
507+
return c.doRequest(ctx, "POST", path, request, c.NeedSticky(), &resp, &respHeaders)
488508
}, Query,
489509
)
490510
if err != nil {
@@ -520,7 +540,7 @@ func (c *APIClient) PollQuery(ctx context.Context, nextURI string) (*QueryRespon
520540
var result QueryResponse
521541
err := c.doRetry(
522542
func() error {
523-
return c.doRequest(ctx, "GET", nextURI, nil, &result, nil)
543+
return c.doRequest(ctx, "GET", nextURI, nil, true, &result, nil)
524544
},
525545
Page,
526546
)
@@ -539,7 +559,7 @@ func (c *APIClient) KillQuery(ctx context.Context, response *QueryResponse) erro
539559
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
540560
defer cancel()
541561
_ = c.doRetry(func() error {
542-
return c.doRequest(ctx, "GET", response.KillURI, nil, nil, nil)
562+
return c.doRequest(ctx, "GET", response.KillURI, nil, true, nil, nil)
543563
}, Kill,
544564
)
545565
}
@@ -551,7 +571,7 @@ func (c *APIClient) CloseQuery(ctx context.Context, response *QueryResponse) err
551571
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
552572
defer cancel()
553573
_ = c.doRetry(func() error {
554-
return c.doRequest(ctx, "GET", response.FinalURI, nil, nil, nil)
574+
return c.doRequest(ctx, "GET", response.FinalURI, nil, true, nil, nil)
555575
}, Final,
556576
)
557577
}
@@ -723,6 +743,14 @@ func (c *APIClient) UploadToStageByAPI(ctx context.Context, stage *StageLocation
723743
return nil
724744
}
725745

746+
func (c *APIClient) Logout(ctx context.Context) error {
747+
if c.NeedKeepAlive() {
748+
req := &struct{}{}
749+
return c.doRequest(ctx, "POST", "/v1/session/logout/", req, c.NeedSticky(), nil, nil)
750+
}
751+
return nil
752+
}
753+
726754
func randRouteHint() string {
727755
charset := "abcdef0123456789"
728756
b := make([]byte, 16)

connection.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ func (dc *DatabendConn) BeginTx(
7878
}
7979

8080
func (dc *DatabendConn) cleanup() {
81-
// must flush log buffer while the process is running.
81+
dc.rest.Logout(dc.ctx)
8282
dc.rest = nil
8383
dc.cfg = nil
8484
}

const.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ const (
66
DatabendQueryIDHeader = "X-DATABEND-QUERY-ID"
77
DatabendRouteHintHeader = "X-DATABEND-ROUTE-HINT"
88
DatabendQueryIDNode = "X-DATABEND-NODE-ID"
9+
DatabendQueryStickyNode = "X-DATABEND-STICKY-NODE"
910
Authorization = "Authorization"
1011
WarehouseRoute = "X-DATABEND-ROUTE"
1112
UserAgent = "User-Agent"

cookie_jar.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package godatabend
2+
3+
import (
4+
"net/http"
5+
"net/url"
6+
"sync"
7+
)
8+
9+
type IgnoreDomainCookieJar struct {
10+
mu sync.Mutex
11+
cookies map[string]*http.Cookie
12+
}
13+
14+
func NewIgnoreDomainCookieJar() *IgnoreDomainCookieJar {
15+
return &IgnoreDomainCookieJar{
16+
cookies: make(map[string]*http.Cookie),
17+
}
18+
}
19+
20+
func (jar *IgnoreDomainCookieJar) SetCookies(_u *url.URL, cookies []*http.Cookie) {
21+
jar.mu.Lock()
22+
defer jar.mu.Unlock()
23+
for _, cookie := range cookies {
24+
jar.cookies[cookie.Name] = cookie
25+
}
26+
}
27+
28+
func (jar *IgnoreDomainCookieJar) Cookies(u *url.URL) []*http.Cookie {
29+
jar.mu.Lock()
30+
defer jar.mu.Unlock()
31+
result := make([]*http.Cookie, 0, len(jar.cookies))
32+
for _, cookie := range jar.cookies {
33+
result = append(result, cookie)
34+
}
35+
return result
36+
}

query.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ type SessionState struct {
104104
Settings map[string]string `json:"settings,omitempty"`
105105

106106
// txn
107-
TxnState TxnState `json:"txn_state,omitempty"` // "Active", "AutoCommit"
107+
TxnState TxnState `json:"txn_state,omitempty"` // "Active", "AutoCommit"
108+
NeedSticky bool `json:"need_sticky,omitempty"`
109+
NeedKeepAlive bool `json:"need_keep_alive,omitempty"`
108110
}
109111

110112
type StageAttachmentConfig struct {

tests/docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ services:
66
volumes:
77
- ./data:/data
88
databend:
9-
image: datafuselabs/databend
9+
image: datafuselabs/databend:nightly
1010
environment:
1111
- QUERY_DEFAULT_USER=databend
1212
- QUERY_DEFAULT_PASSWORD=databend

tests/session_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package tests
22

33
import (
4+
"context"
45
"database/sql"
56
"fmt"
67
"github.com/stretchr/testify/require"
@@ -95,3 +96,32 @@ func (s *DatabendTestSuite) TestSessionVariable() {
9596
r.Nil(err)
9697
r.Equal(int64(100), result)
9798
}
99+
100+
func (s *DatabendTestSuite) TestTempTable() {
101+
r := require.New(s.T())
102+
103+
var result int64
104+
ctx := context.Background()
105+
conn, err := s.db.Conn(ctx)
106+
defer func() {
107+
err = conn.Close()
108+
r.Nil(err)
109+
}()
110+
_, err = conn.ExecContext(ctx, "create temp table t_temp (a int64)")
111+
r.Nil(err)
112+
_, err = conn.ExecContext(ctx, "insert into t_temp values (1), (2)")
113+
r.Nil(err)
114+
rows, err := conn.QueryContext(ctx, "select * from t_temp")
115+
r.Nil(err)
116+
defer rows.Close()
117+
118+
r.True(rows.Next())
119+
err = rows.Scan(&result)
120+
r.Equal(int64(1), result)
121+
122+
r.True(rows.Next())
123+
err = rows.Scan(&result)
124+
r.Equal(int64(2), result)
125+
126+
r.False(rows.Next())
127+
}

0 commit comments

Comments
 (0)