Skip to content

Commit ac750a4

Browse files
committed
POC planbuilder authorization
1 parent c5725b1 commit ac750a4

File tree

4 files changed

+108
-0
lines changed

4 files changed

+108
-0
lines changed

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,6 @@ require (
4343
gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b // indirect
4444
)
4545

46+
replace github.com/dolthub/vitess => ../vitess
47+
4648
go 1.22.2

sql/planbuilder/auth.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// Copyright 2023 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package planbuilder
16+
17+
import (
18+
"fmt"
19+
"github.com/dolthub/go-mysql-server/sql"
20+
"github.com/dolthub/go-mysql-server/sql/mysql_db"
21+
"github.com/dolthub/vitess/go/mysql"
22+
ast "github.com/dolthub/vitess/go/vt/sqlparser"
23+
)
24+
25+
// TODO: doc
26+
var Authorization = func(b *Builder, auth ast.AuthInformation) {
27+
// TODO: expose necessary stuff from Builder to public
28+
ctx := b.ctx
29+
// TODO: database loading shouldn't be done for every call, this needs to be cached for each Parse call in some way
30+
db, err := b.cat.Database(ctx, "mysql")
31+
if err != nil {
32+
b.handleErr(err)
33+
}
34+
mysqlDb, ok := db.(*mysql_db.MySQLDb)
35+
if !ok {
36+
b.handleErr(fmt.Errorf("FOR TESTING: could not load the `mysql` database")) // Check if this is likely
37+
}
38+
if !mysqlDb.Enabled() {
39+
return
40+
}
41+
client := ctx.Session.Client()
42+
user := func() *mysql_db.User {
43+
rd := mysqlDb.Reader()
44+
defer rd.Close()
45+
return mysqlDb.GetUser(rd, client.User, client.Address, false)
46+
}()
47+
if user == nil {
48+
// TODO: Builder should have a HandleAuthErr
49+
b.handleErr(mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", ctx.Session.Client().User))
50+
}
51+
52+
var privilegeTypes []sql.PrivilegeType
53+
switch auth.AuthType {
54+
case ast.AuthType_IGNORE:
55+
// This means that authorization is being handled elsewhere, and should be ignored here
56+
return
57+
case ast.AuthType_DELETE:
58+
privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Delete}
59+
case ast.AuthType_INSERT:
60+
privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Insert}
61+
case ast.AuthType_REPLACE:
62+
privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Insert, sql.PrivilegeType_Delete}
63+
case ast.AuthType_SELECT:
64+
privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Select}
65+
case ast.AuthType_UPDATE:
66+
privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Update}
67+
default:
68+
b.handleErr(fmt.Errorf("FOR TESTING: default case hit for AuthType"))
69+
}
70+
71+
hasPrivileges := true
72+
switch auth.TargetType {
73+
case ast.AuthTargetType_SingleTableIdentifier:
74+
subject := sql.PrivilegeCheckSubject{
75+
Database: authDatabaseName(ctx, auth.TargetNames[0]),
76+
Table: auth.TargetNames[1],
77+
}
78+
hasPrivileges = mysqlDb.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, privilegeTypes...))
79+
case ast.AuthTargetType_MultipleTableIdentifiers:
80+
for i := 0; i < len(auth.TargetNames); i += 2 {
81+
subject := sql.PrivilegeCheckSubject{
82+
Database: authDatabaseName(ctx, auth.TargetNames[i]),
83+
Table: auth.TargetNames[i+1],
84+
}
85+
hasPrivileges = hasPrivileges && mysqlDb.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, privilegeTypes...))
86+
}
87+
default:
88+
b.handleErr(fmt.Errorf("FOR TESTING: default case hit for TargetType"))
89+
}
90+
91+
if !hasPrivileges {
92+
b.handleErr(sql.ErrPrivilegeCheckFailed.New(user.UserHostToString("'")))
93+
}
94+
}
95+
96+
// authDatabaseName uses the current database from the context if a database is not specified, otherwise it returns the
97+
// given database name.
98+
func authDatabaseName(ctx *sql.Context, dbName string) string {
99+
if len(dbName) == 0 {
100+
return ctx.GetCurrentDatabase()
101+
}
102+
return dbName
103+
}

sql/planbuilder/dml.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ func (b *Builder) buildInsert(inScope *scope, i *ast.Insert) (outScope *scope) {
152152
checks := b.loadChecksFromTable(destScope, rt.Table)
153153
outScope.node = ins.WithChecks(checks)
154154
}
155+
Authorization(b, i.Auth)
155156

156157
return
157158
}
@@ -475,6 +476,7 @@ func (b *Builder) buildDelete(inScope *scope, d *ast.Delete) (outScope *scope) {
475476
del.RefsSingleRel = !outScope.refsSubquery
476477
del.IsProcNested = b.ProcCtx().DbName != ""
477478
outScope.node = del
479+
Authorization(b, d.Auth)
478480
return
479481
}
480482

sql/planbuilder/from.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,7 @@ func (b *Builder) buildDataSource(inScope *scope, te ast.TableExpr) (outScope *s
383383
default:
384384
b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(te)))
385385
}
386+
Authorization(b, t.Auth)
386387

387388
case *ast.TableFuncExpr:
388389
return b.buildTableFunc(inScope, t)

0 commit comments

Comments
 (0)