Skip to content

Commit 897d3cf

Browse files
authored
feat: handle secondary roles in the session (#83)
* rename SessionConf to SessionState * save session state * restrict the secondary roles * add more comments
1 parent 4d8d28a commit 897d3cf

File tree

4 files changed

+87
-32
lines changed

4 files changed

+87
-32
lines changed

query.go

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,14 @@ type DataField struct {
2727
}
2828

2929
type QueryResponse struct {
30-
ID string `json:"id"`
31-
SessionID string `json:"session_id"`
32-
Session *SessionConfig `json:"session"`
33-
Schema []DataField `json:"schema"`
34-
Data [][]string `json:"data"`
35-
State string `json:"state"`
36-
Error *QueryError `json:"error"`
37-
Stats QueryStats `json:"stats"`
30+
ID string `json:"id"`
31+
SessionID string `json:"session_id"`
32+
Session *SessionState `json:"session"`
33+
Schema []DataField `json:"schema"`
34+
Data [][]string `json:"data"`
35+
State string `json:"state"`
36+
Error *QueryError `json:"error"`
37+
Stats QueryStats `json:"stats"`
3838
// TODO: Affect rows
3939
StatsURI string `json:"stats_uri"`
4040
FinalURI string `json:"final_uri"`
@@ -62,7 +62,7 @@ type QueryRequest struct {
6262
// We use client session instead of server session with session_id
6363
// SessionID string `json:"session_id,omitempty"`
6464

65-
Session *SessionConfig `json:"session,omitempty"`
65+
Session *SessionState `json:"session,omitempty"`
6666
SQL string `json:"sql"`
6767
Pagination *PaginationConfig `json:"pagination,omitempty"`
6868

@@ -80,9 +80,10 @@ type PaginationConfig struct {
8080
MaxRowsPerPage int64 `json:"max_rows_per_page,omitempty"`
8181
}
8282

83-
type SessionConfig struct {
84-
Database string `json:"database,omitempty"`
85-
Role string `json:"role,omitempty"`
83+
type SessionState struct {
84+
Database string `json:"database,omitempty"`
85+
Role string `json:"role,omitempty"`
86+
SecondaryRoles *[]string `json:"secondary_roles,omitempty"`
8687

8788
// Since we use client session, this should not be used
8889
// KeepServerSessionSecs uint64 `json:"keep_server_session_secs,omitempty"`

query_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package godatabend
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/test-go/testify/require"
9+
)
10+
11+
func Test_SessionState(t *testing.T) {
12+
ss := &SessionState{
13+
Database: "db1",
14+
Role: "",
15+
SecondaryRoles: nil,
16+
Settings: map[string]string{},
17+
}
18+
buf, err := json.Marshal(ss)
19+
require.NoError(t, err)
20+
assert.Equal(t, `{"database":"db1"}`, string(buf))
21+
22+
buf = []byte(`{"database":"db1", "secondary_roles": []}`)
23+
err = json.Unmarshal(buf, ss)
24+
require.NoError(t, err)
25+
assert.Equal(t, []string{}, *ss.SecondaryRoles)
26+
27+
buf = []byte(`{"database":"db1", "secondary_roles": null}`)
28+
err = json.Unmarshal(buf, ss)
29+
require.NoError(t, err)
30+
assert.Nil(t, ss.SecondaryRoles)
31+
32+
buf = []byte(`{"database":"db1"}`)
33+
err = json.Unmarshal(buf, ss)
34+
require.NoError(t, err)
35+
assert.Nil(t, ss.SecondaryRoles)
36+
}

restful.go

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -63,17 +63,19 @@ func NewDefaultCopyOptions() map[string]string {
6363
type APIClient struct {
6464
cli *http.Client
6565

66-
apiEndpoint string
67-
host string
68-
tenant string
69-
warehouse string
70-
database string
71-
user string
72-
password string
73-
role string
74-
accessTokenLoader AccessTokenLoader
75-
sessionSettings map[string]string
66+
apiEndpoint string
67+
host string
68+
tenant string
69+
warehouse string
70+
database string
71+
user string
72+
password string
73+
role string
74+
secondaryRoles *[]string
75+
sessionSettings map[string]string
76+
7677
statsTracker QueryStatsTracker
78+
accessTokenLoader AccessTokenLoader
7779

7880
WaitTimeSeconds int64
7981
MaxRowsInBuffer int64
@@ -92,6 +94,19 @@ func NewAPIClientFromConfig(cfg *Config) *APIClient {
9294
default:
9395
apiScheme = "https"
9496
}
97+
98+
// if role is set in config, we'd prefer to limit it as the only effective role,
99+
// so you could limit the privileges by setting a role with limited privileges.
100+
// however this can be overridden by executing `SET SECONDARY ROLES ALL` in the
101+
// query.
102+
// secondaryRoles now have two viable values:
103+
// - nil: means enabling ALL the granted roles of the user
104+
// - []string{}: means enabling NONE of the granted roles
105+
var secondaryRoles *[]string
106+
if len(cfg.Role) > 0 {
107+
secondaryRoles = &[]string{}
108+
}
109+
95110
return &APIClient{
96111
cli: &http.Client{
97112
Timeout: cfg.Timeout,
@@ -104,6 +119,7 @@ func NewAPIClientFromConfig(cfg *Config) *APIClient {
104119
user: cfg.User,
105120
password: cfg.Password,
106121
role: cfg.Role,
122+
secondaryRoles: secondaryRoles,
107123
accessTokenLoader: initAccessTokenLoader(cfg),
108124
sessionSettings: cfg.Params,
109125
statsTracker: cfg.StatsTracker,
@@ -274,11 +290,12 @@ func (c *APIClient) getPagenationConfig() *PaginationConfig {
274290
}
275291
}
276292

277-
func (c *APIClient) getSessionConfig() *SessionConfig {
278-
return &SessionConfig{
279-
Database: c.database,
280-
Role: c.role,
281-
Settings: c.sessionSettings,
293+
func (c *APIClient) getSessionState() *SessionState {
294+
return &SessionState{
295+
Database: c.database,
296+
Role: c.role,
297+
SecondaryRoles: c.secondaryRoles,
298+
Settings: c.sessionSettings,
282299
}
283300
}
284301

@@ -290,7 +307,7 @@ func (c *APIClient) DoQuery(ctx context.Context, query string, args []driver.Val
290307
request := QueryRequest{
291308
SQL: q,
292309
Pagination: c.getPagenationConfig(),
293-
Session: c.getSessionConfig(),
310+
Session: c.getSessionState(),
294311
}
295312

296313
path := "/v1/query"
@@ -302,12 +319,12 @@ func (c *APIClient) DoQuery(ctx context.Context, query string, args []driver.Val
302319
if result.Error != nil {
303320
return nil, errors.Wrap(result.Error, "query error")
304321
}
305-
c.applySessionConfig(&result)
322+
c.applySessionState(&result)
306323
c.trackStats(&result)
307324
return &result, nil
308325
}
309326

310-
func (c *APIClient) applySessionConfig(response *QueryResponse) {
327+
func (c *APIClient) applySessionState(response *QueryResponse) {
311328
if response.Session == nil {
312329
return
313330
}
@@ -317,6 +334,7 @@ func (c *APIClient) applySessionConfig(response *QueryResponse) {
317334
if len(response.Session.Role) > 0 {
318335
c.role = response.Session.Role
319336
}
337+
c.secondaryRoles = response.Session.SecondaryRoles
320338
if response.Session.Settings != nil {
321339
newSessionSettings := map[string]string{}
322340
for k, v := range response.Session.Settings {
@@ -462,7 +480,7 @@ func (c *APIClient) InsertWithStage(ctx context.Context, sql string, stage *Stag
462480
request := QueryRequest{
463481
SQL: sql,
464482
Pagination: c.getPagenationConfig(),
465-
Session: c.getSessionConfig(),
483+
Session: c.getSessionState(),
466484
StageAttachment: &StageAttachmentConfig{
467485
Location: stage.String(),
468486
FileFormatOptions: fileFormatOptions,

restful_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func TestMakeHeadersUserPassword(t *testing.T) {
2222
assert.Nil(t, err)
2323
assert.Equal(t, headers["Authorization"], []string{"Basic cm9vdDpyb290"})
2424
assert.Equal(t, headers["X-Databend-Tenant"], []string{"default"})
25-
session := c.getSessionConfig()
25+
session := c.getSessionState()
2626
assert.Equal(t, session.Role, "role1")
2727
}
2828

0 commit comments

Comments
 (0)