Skip to content

Commit 7ff122d

Browse files
refactor after rebase
1 parent 04b417d commit 7ff122d

File tree

6 files changed

+61
-64
lines changed

6 files changed

+61
-64
lines changed

pkg/connector/databases.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ func (o *databaseBuilder) Grants(ctx context.Context, resource *v2.Resource, _ r
111111
return nil, nil, nil
112112
}
113113

114-
owner, ownerResp, err := o.client.GetAccountRole(ctx, database.Owner)
114+
owner, ownerStatusCode, err := o.client.GetAccountRole(ctx, database.Owner)
115115
if err != nil {
116-
if snowflake.IsUnprocessableEntity(ownerResp, err) {
116+
if snowflake.IsUnprocessableEntity(ownerStatusCode, err) {
117117
wrappedErr := fmt.Errorf("baton-snowflake: insufficient privileges for database owner role %q (database %q): %w", database.Owner, resource.Id.Resource, err)
118118
return nil, nil, status.Error(codes.PermissionDenied, wrappedErr.Error())
119119
}

pkg/connector/tables.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ func (o *tableBuilder) isDBSharedOrSystem(ctx context.Context, resource *v2.Reso
6969
return val == "true" || val == "1", nil
7070
}
7171
}
72-
db, resp, err := o.client.GetDatabase(ctx, databaseName)
73-
if snowflake.IsUnprocessableEntity(resp, err) {
72+
db, statusCode, err := o.client.GetDatabase(ctx, databaseName)
73+
if snowflake.IsUnprocessableEntity(statusCode, err) {
7474
return true, nil
7575
}
7676
if err != nil {
@@ -148,7 +148,7 @@ func (o *tableBuilder) List(ctx context.Context, parentResourceID *v2.ResourceId
148148
}
149149

150150
const accountPageSize = 200
151-
tables, nextCursor, _, err := o.client.ListTablesInAccount(ctx, cursor, accountPageSize)
151+
tables, nextCursor, err := o.client.ListTablesInAccount(ctx, cursor, accountPageSize)
152152
if err != nil {
153153
return nil, nil, wrapError(err, "failed to list tables in account")
154154
}
@@ -280,9 +280,9 @@ func (o *tableBuilder) Grants(ctx context.Context, resource *v2.Resource, opts r
280280

281281
switch tg.GrantedTo {
282282
case grantedToRole:
283-
role, resp, err := o.client.GetAccountRole(ctx, tg.GranteeName)
283+
role, statusCode, err := o.client.GetAccountRole(ctx, tg.GranteeName)
284284
if err != nil {
285-
if snowflake.IsUnprocessableEntity(resp, err) {
285+
if snowflake.IsUnprocessableEntity(statusCode, err) {
286286
principalId, idErr := rs.NewResourceID(accountRoleResourceType, tg.GranteeName)
287287
if idErr != nil {
288288
continue
@@ -335,14 +335,14 @@ func (o *tableBuilder) Grants(ctx context.Context, resource *v2.Resource, opts r
335335
}
336336

337337
if ownerPrincipalID == nil {
338-
table, _, err := o.client.GetTable(ctx, databaseName, schemaName, tableName)
338+
table, err := o.client.GetTable(ctx, databaseName, schemaName, tableName)
339339
if err != nil {
340340
return nil, nil, wrapError(err, "failed to get table for owner fallback")
341341
}
342342
if table != nil && table.Owner != "" && table.Owner != "SNOWFLAKE" {
343-
owner, ownerResp, err := o.client.GetAccountRole(ctx, table.Owner)
343+
owner, ownerStatusCode, err := o.client.GetAccountRole(ctx, table.Owner)
344344
switch {
345-
case snowflake.IsUnprocessableEntity(ownerResp, err):
345+
case snowflake.IsUnprocessableEntity(ownerStatusCode, err):
346346
// system role, skip
347347
case err != nil:
348348
return nil, nil, wrapError(err, fmt.Sprintf("failed to get account role for table owner %q", table.Owner))

pkg/snowflake/account_role.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,33 +138,37 @@ func (c *Client) ListAccountRoleGrantees(ctx context.Context, roleName string) (
138138
return accountRoleGrantees, nil
139139
}
140140

141-
func (c *Client) GetAccountRole(ctx context.Context, roleName string) (*AccountRole, error) {
141+
func (c *Client) GetAccountRole(ctx context.Context, roleName string) (*AccountRole, int, error) {
142142
queries := []string{
143143
fmt.Sprintf("SHOW ROLES LIKE '%s' LIMIT 1;", roleName),
144144
}
145145

146146
req, err := c.PostStatementRequest(ctx, queries)
147147
if err != nil {
148-
return nil, err
148+
return nil, 0, err
149149
}
150150

151151
var response ListAccountRolesRawResponse
152152
resp, err := c.Do(req, uhttp.WithJSONResponse(&response))
153153
defer closeResponseBody(resp)
154154
if err != nil {
155-
return nil, err
155+
statusCode := 0
156+
if resp != nil {
157+
statusCode = resp.StatusCode
158+
}
159+
return nil, statusCode, err
156160
}
157161

158162
accountRoles, err := response.GetAccountRoles()
159163
if err != nil {
160-
return nil, err
164+
return nil, resp.StatusCode, err
161165
}
162166

163167
if len(accountRoles) == 0 {
164-
return nil, nil
168+
return nil, resp.StatusCode, nil
165169
}
166170

167-
return &accountRoles[0], nil
171+
return &accountRoles[0], resp.StatusCode, nil
168172
}
169173

170174
func (c *Client) GrantAccountRole(ctx context.Context, roleName, userName string) error {

pkg/snowflake/database.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -101,36 +101,37 @@ func (c *Client) ListDatabases(ctx context.Context, cursor string, limit int) ([
101101
return dbs, nil
102102
}
103103

104-
func (c *Client) GetDatabase(ctx context.Context, name string) (*Database, error) {
104+
func (c *Client) GetDatabase(ctx context.Context, name string) (*Database, int, error) {
105105
queries := []string{
106106
fmt.Sprintf("SHOW DATABASES LIKE '%s' LIMIT 1;", name),
107107
}
108108

109109
req, err := c.PostStatementRequest(ctx, queries)
110110
if err != nil {
111-
return nil, err
111+
return nil, 0, err
112112
}
113113

114114
var response ListDatabasesRawResponse
115115
resp, err := c.Do(req, uhttp.WithJSONResponse(&response))
116116
defer closeResponseBody(resp)
117117
if err != nil {
118-
if IsUnprocessableEntity(resp, err) {
119-
return nil, resp, nil
118+
statusCode := 0
119+
if resp != nil {
120+
statusCode = resp.StatusCode
120121
}
121-
return nil, resp, err
122+
return nil, statusCode, err
122123
}
123124

124125
databases, err := response.GetDatabases()
125126
if err != nil {
126-
return nil, err
127+
return nil, resp.StatusCode, err
127128
}
128129

129130
if len(databases) == 0 {
130-
return nil, fmt.Errorf("database with name %s not found", name)
131+
return nil, resp.StatusCode, fmt.Errorf("database with name %s not found", name)
131132
} else if len(databases) > 1 {
132-
return nil, fmt.Errorf("expected 1 database with name %s, got %d", name, len(databases))
133+
return nil, resp.StatusCode, fmt.Errorf("expected 1 database with name %s, got %d", name, len(databases))
133134
}
134135

135-
return &databases[0], nil
136+
return &databases[0], resp.StatusCode, nil
136137
}

pkg/snowflake/helper.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ import (
88
// IsUnprocessableEntity reports whether the Snowflake API returned HTTP 422 (Unprocessable Entity).
99
// Snowflake returns 422 for certain operations on system/predefined objects (e.g. SHOW GRANTS OF ROLE for ACCOUNTADMIN,
1010
// SHOW ROLES LIKE for some roles). Callers can treat this as "no data" or "not resolvable" instead of a hard error.
11-
func IsUnprocessableEntity(resp *http.Response, err error) bool {
12-
if resp != nil && resp.StatusCode == http.StatusUnprocessableEntity {
11+
func IsUnprocessableEntity(statusCode int, err error) bool {
12+
if statusCode == http.StatusUnprocessableEntity {
1313
return true
1414
}
1515
if err != nil && (strings.Contains(err.Error(), "422") || strings.Contains(err.Error(), "Unprocessable Entity")) {

pkg/snowflake/table.go

Lines changed: 29 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func (r *ListTablesRawResponse) ListTables() ([]Table, error) {
6161

6262
const tableListCursorSep = "\x00"
6363

64-
func (c *Client) ListTablesInAccount(ctx context.Context, cursor string, limit int) ([]Table, string, *http.Response, error) {
64+
func (c *Client) ListTablesInAccount(ctx context.Context, cursor string, limit int) ([]Table, string, error) {
6565
l := ctxzap.Extract(ctx)
6666

6767
var q string
@@ -81,51 +81,47 @@ func (c *Client) ListTablesInAccount(ctx context.Context, cursor string, limit i
8181

8282
req, err := c.PostStatementRequest(ctx, queries)
8383
if err != nil {
84-
return nil, "", nil, err
84+
return nil, "", err
8585
}
8686

8787
var response ListTablesRawResponse
88-
resp, err := c.Do(req, uhttp.WithJSONResponse(&response))
88+
resp1, err := c.Do(req, uhttp.WithJSONResponse(&response))
89+
defer closeResponseBody(resp1)
8990
if err != nil {
90-
if resp != nil && resp.StatusCode == http.StatusUnprocessableEntity {
91+
if resp1 != nil && resp1.StatusCode == http.StatusUnprocessableEntity {
9192
l.Debug("Insufficient privileges for SHOW TABLES IN ACCOUNT")
9293
wrappedErr := fmt.Errorf("baton-snowflake: insufficient privileges for SHOW TABLES IN ACCOUNT: %w", err)
93-
return nil, "", nil, status.Error(codes.PermissionDenied, wrappedErr.Error())
94+
return nil, "", status.Error(codes.PermissionDenied, wrappedErr.Error())
9495
}
95-
return nil, "", nil, err
96-
}
97-
if resp != nil {
98-
defer resp.Body.Close()
96+
return nil, "", err
9997
}
10098

10199
req, err = c.GetStatementResponse(ctx, response.StatementHandle)
102100
if err != nil {
103-
return nil, "", resp, err
101+
return nil, "", err
104102
}
105-
resp, err = c.Do(req, uhttp.WithJSONResponse(&response))
103+
resp2, err := c.Do(req, uhttp.WithJSONResponse(&response))
104+
defer closeResponseBody(resp2)
106105
if err != nil {
107-
if resp != nil && resp.StatusCode == http.StatusUnprocessableEntity {
106+
if resp2 != nil && resp2.StatusCode == http.StatusUnprocessableEntity {
108107
l.Debug("Insufficient privileges for SHOW TABLES IN ACCOUNT (statement result)")
109108
wrappedErr := fmt.Errorf("baton-snowflake: insufficient privileges for SHOW TABLES IN ACCOUNT (statement result): %w", err)
110-
return nil, "", nil, status.Error(codes.PermissionDenied, wrappedErr.Error())
109+
return nil, "", status.Error(codes.PermissionDenied, wrappedErr.Error())
111110
}
112-
return nil, "", resp, err
113-
}
114-
if resp != nil {
115-
defer resp.Body.Close()
111+
return nil, "", err
116112
}
117113

118114
tables, err := response.ListTables()
119115
if err != nil {
120-
return nil, "", resp, err
116+
return nil, "", err
121117
}
122118

123119
var nextCursor string
124120
if len(tables) >= limit {
125121
last := tables[len(tables)-1]
126122
nextCursor = last.DatabaseName + tableListCursorSep + last.SchemaName + tableListCursorSep + last.Name
127123
}
128-
return tables, nextCursor, resp, nil
124+
return tables, nextCursor, nil
129125
}
130126

131127
// escapeSingleQuote doubles single quotes for use inside SQL string literals.
@@ -149,54 +145,50 @@ func escapeDoubleQuotedIdentifier(s string) string {
149145
return strings.ReplaceAll(s, `"`, `""`)
150146
}
151147

152-
func (c *Client) GetTable(ctx context.Context, database, schema, tableName string) (*Table, *http.Response, error) {
148+
func (c *Client) GetTable(ctx context.Context, database, schema, tableName string) (*Table, error) {
153149
likePattern := escapeLikePattern(tableName)
154150
queries := []string{
155151
fmt.Sprintf("SHOW TABLES LIKE '%s' ESCAPE '\\' IN SCHEMA \"%s\".\"%s\" LIMIT 1;", likePattern, escapeDoubleQuotedIdentifier(database), escapeDoubleQuotedIdentifier(schema)),
156152
}
157153

158154
req, err := c.PostStatementRequest(ctx, queries)
159155
if err != nil {
160-
return nil, nil, err
156+
return nil, err
161157
}
162158

163159
var response ListTablesRawResponse
164-
resp, err := c.Do(req, uhttp.WithJSONResponse(&response))
160+
resp1, err := c.Do(req, uhttp.WithJSONResponse(&response))
161+
defer closeResponseBody(resp1)
165162
if err != nil {
166-
if resp != nil && resp.StatusCode == http.StatusUnprocessableEntity {
167-
return nil, resp, nil
163+
if resp1 != nil && resp1.StatusCode == http.StatusUnprocessableEntity {
164+
return nil, nil
168165
}
169-
return nil, nil, err
170-
}
171-
if resp != nil {
172-
defer resp.Body.Close()
166+
return nil, err
173167
}
174168

175169
req, err = c.GetStatementResponse(ctx, response.StatementHandle)
176170
if err != nil {
177-
return nil, resp, err
171+
return nil, err
178172
}
179-
resp, err = c.Do(req, uhttp.WithJSONResponse(&response))
173+
resp2, err := c.Do(req, uhttp.WithJSONResponse(&response))
174+
defer closeResponseBody(resp2)
180175
if err != nil {
181-
return nil, resp, err
182-
}
183-
if resp != nil {
184-
defer resp.Body.Close()
176+
return nil, err
185177
}
186178

187179
tables, err := response.ListTables()
188180
if err != nil {
189-
return nil, resp, err
181+
return nil, err
190182
}
191183

192184
// Filter by exact match (database, schema, and name)
193185
for _, table := range tables {
194186
if table.DatabaseName == database && table.SchemaName == schema && table.Name == tableName {
195-
return &table, resp, nil
187+
return &table, nil
196188
}
197189
}
198190

199-
return nil, resp, fmt.Errorf("table %s.%s.%s not found", database, schema, tableName)
191+
return nil, fmt.Errorf("table %s.%s.%s not found", database, schema, tableName)
200192
}
201193

202194
var tableGrantStructFieldToColumnMap = map[string]string{

0 commit comments

Comments
 (0)