Skip to content

Commit cbb9ee1

Browse files
committed
POC planbuilder authorization
1 parent c5725b1 commit cbb9ee1

File tree

10 files changed

+153
-164
lines changed

10 files changed

+153
-164
lines changed

enginetest/queries/priv_auth_queries.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1780,10 +1780,10 @@ var UserPrivTests = []UserPrivilegeTest{
17801780
},
17811781
},
17821782
{
1783-
User: "rand_user1",
1784-
Host: "54.244.85.252",
1785-
Query: "SELECT * FROM mydb.test;",
1786-
ExpectedErr: sql.ErrDatabaseAccessDeniedForUser,
1783+
User: "rand_user1",
1784+
Host: "54.244.85.252",
1785+
Query: "SELECT * FROM mydb.test;",
1786+
ExpectedErrStr: "Access denied for user 'rand_user1' (errno 1045) (sqlstate 28000)",
17871787
},
17881788
{
17891789
User: "rand_user2",
@@ -1804,10 +1804,10 @@ var UserPrivTests = []UserPrivilegeTest{
18041804
},
18051805
},
18061806
{
1807-
User: "rand_user2",
1808-
Host: "54.244.85.252",
1809-
Query: "SELECT * FROM mydb.test2;",
1810-
ExpectedErr: sql.ErrDatabaseAccessDeniedForUser,
1807+
User: "rand_user2",
1808+
Host: "54.244.85.252",
1809+
Query: "SELECT * FROM mydb.test2;",
1810+
ExpectedErrStr: "Access denied for user 'rand_user2' (errno 1045) (sqlstate 28000)",
18111811
},
18121812
},
18131813
},

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,6 @@ require (
4343
gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b // indirect
4444
)
4545

46+
replace github.com/dolthub/vitess => ../vitess
47+
4648
go 1.22.2

sql/mysql_db/privileged_database_provider.go

Lines changed: 2 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -164,58 +164,12 @@ func (pdb PrivilegedDatabase) Name() string {
164164

165165
// GetTableInsensitive implements the interface sql.Database.
166166
func (pdb PrivilegedDatabase) GetTableInsensitive(ctx *sql.Context, tblName string) (sql.Table, bool, error) {
167-
checkName := pdb.db.Name()
168-
if adb, ok := pdb.db.(sql.AliasedDatabase); ok {
169-
checkName = adb.AliasedName()
170-
}
171-
172-
privSet := pdb.grantTables.UserActivePrivilegeSet(ctx)
173-
dbSet := privSet.Database(checkName)
174-
// If there are no usable privileges for this database then the table is inaccessible.
175-
if privSet.Count() == 0 && !dbSet.HasPrivileges() {
176-
return nil, false, sql.ErrDatabaseAccessDeniedForUser.New(pdb.usernameFromCtx(ctx), checkName)
177-
}
178-
179-
tblSet := dbSet.Table(tblName)
180-
// If the user has no global static privileges, database-level privileges, or table-relevant privileges then the
181-
// table is not accessible.
182-
if privSet.Count() == 0 && dbSet.Count() == 0 && !tblSet.HasPrivileges() {
183-
return nil, false, sql.ErrTableAccessDeniedForUser.New(pdb.usernameFromCtx(ctx), tblName)
184-
}
185167
return pdb.db.GetTableInsensitive(ctx, tblName)
186168
}
187169

188170
// GetTableNames implements the interface sql.Database.
189171
func (pdb PrivilegedDatabase) GetTableNames(ctx *sql.Context) ([]string, error) {
190-
var tablesWithAccess []string
191-
var err error
192-
privSet := pdb.grantTables.UserActivePrivilegeSet(ctx)
193-
194-
checkName := pdb.db.Name()
195-
if adb, ok := pdb.db.(sql.AliasedDatabase); ok {
196-
checkName = adb.AliasedName()
197-
}
198-
199-
dbSet := privSet.Database(checkName)
200-
// If there are no usable privileges for this database then no table is accessible.
201-
privSetCount := privSet.Count()
202-
if privSetCount == 0 && !dbSet.HasPrivileges() {
203-
return nil, nil
204-
}
205-
206-
tblNames, err := pdb.db.GetTableNames(ctx)
207-
if err != nil {
208-
return nil, err
209-
}
210-
dbSetCount := dbSet.Count()
211-
for _, tblName := range tblNames {
212-
// If the user has any global static privileges, database-level privileges, or table-relevant privileges then a
213-
// table is accessible.
214-
if privSetCount > 0 || dbSetCount > 0 || dbSet.Table(tblName).HasPrivileges() {
215-
tablesWithAccess = append(tablesWithAccess, tblName)
216-
}
217-
}
218-
return tablesWithAccess, nil
172+
return pdb.db.GetTableNames(ctx)
219173
}
220174

221175
// GetTableInsensitiveAsOf returns a new sql.VersionedDatabase.
@@ -224,26 +178,6 @@ func (pdb PrivilegedDatabase) GetTableInsensitiveAsOf(ctx *sql.Context, tblName
224178
if !ok {
225179
return nil, false, sql.ErrAsOfNotSupported.New(pdb.db.Name())
226180
}
227-
228-
privSet := pdb.grantTables.UserActivePrivilegeSet(ctx)
229-
230-
checkName := pdb.db.Name()
231-
if adb, ok := pdb.db.(sql.AliasedDatabase); ok {
232-
checkName = adb.AliasedName()
233-
}
234-
235-
dbSet := privSet.Database(checkName)
236-
// If there are no usable privileges for this database then the table is inaccessible.
237-
if privSet.Count() == 0 && !dbSet.HasPrivileges() {
238-
return nil, false, sql.ErrDatabaseAccessDeniedForUser.New(pdb.usernameFromCtx(ctx), checkName)
239-
}
240-
241-
tblSet := dbSet.Table(tblName)
242-
// If the user has no global static privileges, database-level privileges, or table-relevant privileges then the
243-
// table is not accessible.
244-
if privSet.Count() == 0 && dbSet.Count() == 0 && !tblSet.HasPrivileges() {
245-
return nil, false, sql.ErrTableAccessDeniedForUser.New(pdb.usernameFromCtx(ctx), tblName)
246-
}
247181
return db.GetTableInsensitiveAsOf(ctx, tblName, asOf)
248182
}
249183

@@ -253,37 +187,7 @@ func (pdb PrivilegedDatabase) GetTableNamesAsOf(ctx *sql.Context, asOf interface
253187
if !ok {
254188
return nil, nil
255189
}
256-
257-
var tablesWithAccess []string
258-
var err error
259-
privSet := pdb.grantTables.UserActivePrivilegeSet(ctx)
260-
261-
checkName := pdb.db.Name()
262-
if adb, ok := pdb.db.(sql.AliasedDatabase); ok {
263-
checkName = adb.AliasedName()
264-
}
265-
266-
dbSet := privSet.Database(checkName)
267-
// If there are no usable privileges for this database then no table is accessible.
268-
if privSet.Count() == 0 && !dbSet.HasPrivileges() {
269-
return nil, nil
270-
}
271-
272-
tblNames, err := db.GetTableNamesAsOf(ctx, asOf)
273-
if err != nil {
274-
return nil, err
275-
}
276-
privSetCount := privSet.Count()
277-
dbSetCount := dbSet.Count()
278-
for _, tblName := range tblNames {
279-
// If the user has any global static privileges, database-level privileges, or table-relevant privileges then a
280-
// table is accessible.
281-
if privSetCount > 0 || dbSetCount > 0 && dbSet.Table(tblName).HasPrivileges() {
282-
tablesWithAccess = append(tablesWithAccess, tblName)
283-
}
284-
}
285-
286-
return tablesWithAccess, nil
190+
return db.GetTableNamesAsOf(ctx, asOf)
287191
}
288192

289193
// CreateTable implements the interface sql.TableCreator.

sql/plan/delete.go

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -115,27 +115,6 @@ func (p *DeleteFrom) WithChildren(children ...sql.Node) (sql.Node, error) {
115115

116116
// CheckPrivileges implements the interface sql.Node.
117117
func (p *DeleteFrom) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
118-
// TODO: If column values are retrieved then the SELECT privilege is required
119-
// For example: "DELETE FROM table WHERE z > 0"
120-
// We would need SELECT privileges on the "z" column as it's retrieving values
121-
122-
for _, target := range p.GetDeleteTargets() {
123-
deletable, err := GetDeletable(target)
124-
if err != nil {
125-
ctx.GetLogger().Warnf("unable to determine deletable table from delete target: %v", target)
126-
return false
127-
}
128-
129-
subject := sql.PrivilegeCheckSubject{
130-
Database: CheckPrivilegeNameForDatabase(GetDatabase(target)),
131-
Table: deletable.Name(),
132-
}
133-
op := sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Delete)
134-
if opChecker.UserHasPrivileges(ctx, op) == false {
135-
return false
136-
}
137-
}
138-
139118
return true
140119
}
141120

sql/plan/insert.go

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -164,18 +164,7 @@ func (ii *InsertInto) WithChildren(children ...sql.Node) (sql.Node, error) {
164164

165165
// CheckPrivileges implements the interface sql.Node.
166166
func (ii *InsertInto) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
167-
subject := sql.PrivilegeCheckSubject{
168-
Database: CheckPrivilegeNameForDatabase(ii.db),
169-
Table: getTableName(ii.Destination),
170-
}
171-
172-
if ii.IsReplace {
173-
return opChecker.UserHasPrivileges(ctx,
174-
sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Insert, sql.PrivilegeType_Delete))
175-
} else {
176-
return opChecker.UserHasPrivileges(ctx,
177-
sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Insert))
178-
}
167+
return true
179168
}
180169

181170
// CollationCoercibility implements the interface sql.CollationCoercible.

sql/plan/resolved_table.go

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -227,22 +227,7 @@ func (t *ResolvedTable) WithChildren(children ...sql.Node) (sql.Node, error) {
227227

228228
// CheckPrivileges implements the interface sql.Node.
229229
func (t *ResolvedTable) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
230-
// It is assumed that if we've landed upon this node, then we're doing a SELECT operation. Most other nodes that
231-
// may contain a TableNode will have their own privilege checks, so we should only end up here if the parent
232-
// nodes are things such as indexed access, filters, limits, etc.
233-
if IsDualTable(t) {
234-
return true
235-
}
236-
237-
subject := sql.PrivilegeCheckSubject{
238-
Database: CheckPrivilegeNameForDatabase(t.SqlDatabase),
239-
Table: t.Table.Name(),
240-
}
241-
if subject.Database == sql.InformationSchemaDatabaseName {
242-
return true
243-
}
244-
return opChecker.UserHasPrivileges(ctx,
245-
sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Select))
230+
return true
246231
}
247232

248233
// CollationCoercibility implements the interface sql.CollationCoercible.

sql/plan/update.go

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -183,15 +183,7 @@ func (u *Update) WithChildren(children ...sql.Node) (sql.Node, error) {
183183

184184
// CheckPrivileges implements the interface sql.Node.
185185
func (u *Update) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
186-
//TODO: If column values are retrieved then the SELECT privilege is required
187-
// For example: "UPDATE table SET x = y + 1 WHERE z > 0"
188-
// We would need SELECT privileges on both the "y" and "z" columns as they're retrieving values
189-
subject := sql.PrivilegeCheckSubject{
190-
Database: CheckPrivilegeNameForDatabase(u.DB()),
191-
Table: getTableName(u.Child),
192-
}
193-
// TODO: this needs a real database, fix it
194-
return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Update))
186+
return true
195187
}
196188

197189
// CollationCoercibility implements the interface sql.CollationCoercible.

sql/planbuilder/auth.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// Copyright 2023 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package planbuilder
16+
17+
import (
18+
"fmt"
19+
"strings"
20+
21+
"github.com/dolthub/vitess/go/mysql"
22+
ast "github.com/dolthub/vitess/go/vt/sqlparser"
23+
24+
"github.com/dolthub/go-mysql-server/sql"
25+
"github.com/dolthub/go-mysql-server/sql/mysql_db"
26+
)
27+
28+
// TODO: doc
29+
var Authorization = func(b *Builder, auth ast.AuthInformation) {
30+
// TODO: expose necessary stuff from Builder to public
31+
ctx := b.ctx
32+
// TODO: database loading shouldn't be done for every call, this needs to be cached for each Parse call in some way
33+
db, err := b.cat.Database(ctx, "mysql")
34+
if err != nil {
35+
b.handleErr(err)
36+
}
37+
mysqlDb, ok := db.(*mysql_db.MySQLDb)
38+
if !ok {
39+
b.handleErr(fmt.Errorf("FOR TESTING: could not load the `mysql` database")) // TODO: Check if this is likely
40+
}
41+
if !mysqlDb.Enabled() {
42+
return
43+
}
44+
// TODO: cache that the user exists
45+
client := ctx.Session.Client()
46+
user := func() *mysql_db.User {
47+
rd := mysqlDb.Reader()
48+
defer rd.Close()
49+
return mysqlDb.GetUser(rd, client.User, client.Address, false)
50+
}()
51+
if user == nil {
52+
b.handleErr(mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%s'", client.User))
53+
}
54+
// TODO: cache for the call
55+
privSet := mysqlDb.UserActivePrivilegeSet(ctx)
56+
57+
var privilegeTypes []sql.PrivilegeType
58+
switch auth.AuthType {
59+
case ast.AuthType_IGNORE:
60+
// This means that authorization is being handled elsewhere (such as a child or parent), and should be ignored here
61+
return
62+
case ast.AuthType_DELETE:
63+
privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Delete}
64+
case ast.AuthType_INSERT:
65+
privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Insert}
66+
case ast.AuthType_REPLACE:
67+
privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Insert, sql.PrivilegeType_Delete}
68+
case ast.AuthType_SELECT:
69+
privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Select}
70+
case ast.AuthType_UPDATE:
71+
privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Update}
72+
default:
73+
b.handleErr(fmt.Errorf("FOR TESTING: default case hit for AuthType"))
74+
}
75+
76+
hasPrivileges := true
77+
switch auth.TargetType {
78+
case ast.AuthTargetType_SingleTableIdentifier:
79+
dbName := auth.TargetNames[0]
80+
tableName := auth.TargetNames[1]
81+
if strings.EqualFold(dbName, "information_schema") {
82+
return
83+
}
84+
authCheckDatabaseTableNames(b, privSet, user.User, dbName, tableName)
85+
subject := sql.PrivilegeCheckSubject{
86+
Database: authDatabaseName(ctx, dbName),
87+
Table: tableName,
88+
}
89+
hasPrivileges = mysqlDb.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, privilegeTypes...))
90+
case ast.AuthTargetType_MultipleTableIdentifiers:
91+
for i := 0; i < len(auth.TargetNames) && hasPrivileges; i += 2 {
92+
dbName := auth.TargetNames[i]
93+
tableName := auth.TargetNames[i+1]
94+
if strings.EqualFold(dbName, "information_schema") {
95+
continue
96+
}
97+
authCheckDatabaseTableNames(b, privSet, user.User, dbName, tableName)
98+
subject := sql.PrivilegeCheckSubject{
99+
Database: authDatabaseName(ctx, dbName),
100+
Table: tableName,
101+
}
102+
hasPrivileges = hasPrivileges && mysqlDb.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, privilegeTypes...))
103+
}
104+
default:
105+
b.handleErr(fmt.Errorf("FOR TESTING: default case hit for TargetType"))
106+
}
107+
108+
if !hasPrivileges {
109+
b.handleErr(sql.ErrPrivilegeCheckFailed.New(user.UserHostToString("'")))
110+
}
111+
}
112+
113+
// authDatabaseName uses the current database from the context if a database is not specified, otherwise it returns the
114+
// given database name.
115+
func authDatabaseName(ctx *sql.Context, dbName string) string {
116+
if len(dbName) == 0 {
117+
return ctx.GetCurrentDatabase()
118+
}
119+
return dbName
120+
}
121+
122+
// authCheckDatabaseTableNames errors if the user does not have access to the database or table in any capacity,
123+
// regardless of the command.
124+
func authCheckDatabaseTableNames(b *Builder, privSet mysql_db.PrivilegeSet, userName string, dbName string, tableName string) {
125+
dbSet := privSet.Database(dbName)
126+
// If there are no usable privileges for this database then the table is inaccessible.
127+
if privSet.Count() == 0 && !dbSet.HasPrivileges() {
128+
b.handleErr(sql.ErrDatabaseAccessDeniedForUser.New(userName, dbName))
129+
}
130+
tblSet := dbSet.Table(tableName)
131+
// If the user has no global static privileges, database-level privileges, or table-relevant privileges then the
132+
// table is not accessible.
133+
if privSet.Count() == 0 && dbSet.Count() == 0 && !tblSet.HasPrivileges() {
134+
b.handleErr(sql.ErrTableAccessDeniedForUser.New(userName, tableName))
135+
}
136+
}

sql/planbuilder/dml.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ func (b *Builder) buildInsert(inScope *scope, i *ast.Insert) (outScope *scope) {
3434
sql.IncrementStatusVariable(b.ctx, "Com_insert", 1)
3535
b.qFlags.Set(sql.QFlagInsert)
3636

37+
Authorization(b, i.Auth)
3738
if i.With != nil {
3839
inScope = b.buildWith(inScope, i.With)
3940
}

0 commit comments

Comments
 (0)