Skip to content

Commit 3f4374b

Browse files
committed
add helper
1 parent 6f4e157 commit 3f4374b

File tree

11 files changed

+334
-68
lines changed

11 files changed

+334
-68
lines changed

pkg/client/columns.go

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,16 +98,25 @@ func (c *Client) GrantColumnPrivilege(ctx context.Context, table string, column
9898
privileges = []string{strings.ToUpper(privilege)}
9999
}
100100

101+
escapedTable, err := escapeMySQLIdent(table)
102+
if err != nil {
103+
return err
104+
}
105+
escapedColumn, err := escapeMySQLIdent(column)
106+
if err != nil {
107+
return err
108+
}
109+
101110
var privilegeClauses []string
102111
for _, priv := range privileges {
103-
privilegeClauses = append(privilegeClauses, fmt.Sprintf("%s (%s)", priv, column))
112+
privilegeClauses = append(privilegeClauses, fmt.Sprintf("%s (%s)", priv, escapedColumn))
104113
}
105114
privilegesSQL := strings.Join(privilegeClauses, ", ")
106115

107-
query := fmt.Sprintf("GRANT %s ON %s TO '%s'", privilegesSQL, table, userGrant)
116+
query := fmt.Sprintf("GRANT %s ON %s TO '%s'", privilegesSQL, escapedTable, userGrant)
108117

109-
_, err := c.db.ExecContext(ctx, query)
110-
return err
118+
_ = c.db.MustExec(query)
119+
return nil
111120
}
112121

113122
func (c *Client) RevokeColumnPrivilege(ctx context.Context, table string, column string, user string, privilege string) error {
@@ -124,14 +133,23 @@ func (c *Client) RevokeColumnPrivilege(ctx context.Context, table string, column
124133
privileges = []string{strings.ToUpper(privilege)}
125134
}
126135

136+
escapedTable, err := escapeMySQLIdent(table)
137+
if err != nil {
138+
return err
139+
}
140+
escapedColumn, err := escapeMySQLIdent(column)
141+
if err != nil {
142+
return err
143+
}
144+
127145
var privilegeClauses []string
128146
for _, priv := range privileges {
129-
privilegeClauses = append(privilegeClauses, fmt.Sprintf("%s (%s)", priv, column))
147+
privilegeClauses = append(privilegeClauses, fmt.Sprintf("%s (%s)", priv, escapedColumn))
130148
}
131149
privilegesSQL := strings.Join(privilegeClauses, ", ")
132150

133-
query := fmt.Sprintf("REVOKE %s ON %s FROM '%s'", privilegesSQL, table, userRevoke)
151+
query := fmt.Sprintf("REVOKE %s ON %s FROM '%s'", privilegesSQL, escapedTable, userRevoke)
134152

135-
_, err := c.db.ExecContext(ctx, query)
136-
return err
153+
_ = c.db.MustExec(query)
154+
return nil
137155
}

pkg/client/databases.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package client
22

33
import (
44
"context"
5+
"fmt"
56
"strconv"
67
"strings"
78

@@ -71,3 +72,53 @@ func (c *Client) ListDatabases(ctx context.Context, pager *Pager) ([]*DbModel, s
7172

7273
return ret, nextPageToken, nil
7374
}
75+
76+
func (c *Client) GrantDatabasePrivilege(ctx context.Context, database string, user string, privilege string) error {
77+
userSplit := strings.Split(user, "@")
78+
if len(userSplit) != 2 {
79+
return fmt.Errorf("invalid user format, expected user@host")
80+
}
81+
userEsc, err := escapeMySQLUserHost(userSplit[0])
82+
if err != nil {
83+
return err
84+
}
85+
hostEsc, err := escapeMySQLUserHost(userSplit[1])
86+
if err != nil {
87+
return err
88+
}
89+
userGrant := fmt.Sprintf("'%s'@'%s'", userEsc, hostEsc)
90+
91+
escapedDB, err := escapeMySQLIdent(database)
92+
if err != nil {
93+
return err
94+
}
95+
96+
query := fmt.Sprintf("GRANT %s ON %s.* TO %s", strings.ToUpper(privilege), escapedDB, userGrant)
97+
_ = c.db.MustExec(query)
98+
return nil
99+
}
100+
101+
func (c *Client) RevokeDatabasePrivilege(ctx context.Context, database string, user string, privilege string) error {
102+
userSplit := strings.Split(user, "@")
103+
if len(userSplit) != 2 {
104+
return fmt.Errorf("invalid user format, expected user@host")
105+
}
106+
userEsc, err := escapeMySQLUserHost(userSplit[0])
107+
if err != nil {
108+
return err
109+
}
110+
hostEsc, err := escapeMySQLUserHost(userSplit[1])
111+
if err != nil {
112+
return err
113+
}
114+
userRevoke := fmt.Sprintf("'%s'@'%s'", userEsc, hostEsc)
115+
116+
escapedDB, err := escapeMySQLIdent(database)
117+
if err != nil {
118+
return err
119+
}
120+
121+
query := fmt.Sprintf("REVOKE %s ON %s.* FROM %s", strings.ToUpper(privilege), escapedDB, userRevoke)
122+
_ = c.db.MustExec(query)
123+
return nil
124+
}

pkg/client/helper.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package client
2+
3+
import (
4+
"fmt"
5+
"regexp"
6+
"strings"
7+
)
8+
9+
// Helper for identifiers (tables, columns, databases)
10+
var validIdent = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
11+
12+
func escapeMySQLIdent(ident string) (string, error) {
13+
parts := strings.Split(ident, ".")
14+
for i, part := range parts {
15+
if !validIdent.MatchString(part) {
16+
return "", fmt.Errorf("invalid identifier: %s", ident)
17+
}
18+
parts[i] = "`" + strings.ReplaceAll(part, "`", "``") + "`"
19+
}
20+
return strings.Join(parts, "."), nil
21+
}
22+
23+
// Helper for user/host
24+
var validUserHost = regexp.MustCompile(`^[a-zA-Z0-9_%\\.\\-]+$`)
25+
26+
func escapeMySQLUserHost(ident string) (string, error) {
27+
if !validUserHost.MatchString(ident) {
28+
return "", fmt.Errorf("invalid user/host: %s", ident)
29+
}
30+
return ident, nil
31+
}

pkg/client/roles.go

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,22 @@ func (c *Client) GrantRolePrivilege(ctx context.Context, role, user, privilege s
1717
return fmt.Errorf("invalid user format: %s", user)
1818
}
1919

20-
roleUser := roleParts[0]
21-
roleHost := roleParts[1]
22-
23-
targetUser := userParts[0]
24-
targetHost := userParts[1]
20+
roleUser, err := escapeMySQLUserHost(roleParts[0])
21+
if err != nil {
22+
return err
23+
}
24+
roleHost, err := escapeMySQLUserHost(roleParts[1])
25+
if err != nil {
26+
return err
27+
}
28+
targetUser, err := escapeMySQLUserHost(userParts[0])
29+
if err != nil {
30+
return err
31+
}
32+
targetHost, err := escapeMySQLUserHost(userParts[1])
33+
if err != nil {
34+
return err
35+
}
2536

2637
var grantStmt string
2738
switch privilege {
@@ -37,9 +48,10 @@ func (c *Client) GrantRolePrivilege(ctx context.Context, role, user, privilege s
3748
return fmt.Errorf("unknown privilege: %s", privilege)
3849
}
3950

40-
_, err := c.db.ExecContext(ctx, grantStmt)
41-
return err
51+
_ = c.db.MustExec(grantStmt)
52+
return nil
4253
}
54+
4355
func (c *Client) RevokeRolePrivilege(ctx context.Context, role, user, privilege string) error {
4456
roleParts := strings.Split(role, "@")
4557
if len(roleParts) != 2 {
@@ -51,11 +63,22 @@ func (c *Client) RevokeRolePrivilege(ctx context.Context, role, user, privilege
5163
return fmt.Errorf("invalid user format: %s", user)
5264
}
5365

54-
roleUser := roleParts[0]
55-
roleHost := roleParts[1]
56-
57-
targetUser := userParts[0]
58-
targetHost := userParts[1]
66+
roleUser, err := escapeMySQLUserHost(roleParts[0])
67+
if err != nil {
68+
return err
69+
}
70+
roleHost, err := escapeMySQLUserHost(roleParts[1])
71+
if err != nil {
72+
return err
73+
}
74+
targetUser, err := escapeMySQLUserHost(userParts[0])
75+
if err != nil {
76+
return err
77+
}
78+
targetHost, err := escapeMySQLUserHost(userParts[1])
79+
if err != nil {
80+
return err
81+
}
5982

6083
var revokeStmt string
6184
switch privilege {
@@ -67,6 +90,6 @@ func (c *Client) RevokeRolePrivilege(ctx context.Context, role, user, privilege
6790
return fmt.Errorf("unknown privilege: %s", privilege)
6891
}
6992

70-
_, err := c.db.ExecContext(ctx, revokeStmt)
71-
return err
93+
_ = c.db.MustExec(revokeStmt)
94+
return nil
7295
}

pkg/client/routines.go

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,34 +89,69 @@ func (c *Client) GrantRoutinePrivilege(ctx context.Context, privilege string, sc
8989
return err
9090
}
9191

92+
schemaEsc, err := escapeMySQLIdent(schema)
93+
if err != nil {
94+
return err
95+
}
96+
routineNameEsc, err := escapeMySQLIdent(routineName)
97+
if err != nil {
98+
return err
99+
}
100+
92101
userSplit := strings.Split(user, "@")
93102
if len(userSplit) != 2 {
94103
return fmt.Errorf("invalid user format, expected user@host")
95104
}
96-
userGrant := fmt.Sprintf("%s'@'%s", userSplit[0], userSplit[1])
105+
userEsc, err := escapeMySQLUserHost(userSplit[0])
106+
if err != nil {
107+
return err
108+
}
109+
hostEsc, err := escapeMySQLUserHost(userSplit[1])
110+
if err != nil {
111+
return err
112+
}
113+
userGrant := fmt.Sprintf("'%s'@'%s'", userEsc, hostEsc)
97114

98-
query := fmt.Sprintf("GRANT %s ON %s %s.%s TO '%s'",
99-
privilege, strings.ToUpper(routineType), schema, routineName, userGrant)
115+
query := fmt.Sprintf("GRANT %s ON %s %s.%s TO %s",
116+
privilege, strings.ToUpper(routineType), schemaEsc, routineNameEsc, userGrant)
100117

101-
_, err = c.db.ExecContext(ctx, query)
102-
return err
118+
_ = c.db.MustExec(query)
119+
return nil
103120
}
121+
104122
func (c *Client) RevokeRoutinePrivilege(ctx context.Context, privilege string, schema string, routineName string, user string) error {
105123
routineType, err := c.GetRoutineType(ctx, schema, routineName)
106124
if err != nil {
107125
return err
108126
}
109127

128+
schemaEsc, err := escapeMySQLIdent(schema)
129+
if err != nil {
130+
return err
131+
}
132+
routineNameEsc, err := escapeMySQLIdent(routineName)
133+
if err != nil {
134+
return err
135+
}
136+
110137
userSplit := strings.Split(user, "@")
111138
if len(userSplit) != 2 {
112139
return fmt.Errorf("invalid user format, expected user@host")
113140
}
114-
userRevoke := fmt.Sprintf("%s'@'%s", userSplit[0], userSplit[1])
141+
userEsc, err := escapeMySQLUserHost(userSplit[0])
142+
if err != nil {
143+
return err
144+
}
145+
hostEsc, err := escapeMySQLUserHost(userSplit[1])
146+
if err != nil {
147+
return err
148+
}
149+
userRevoke := fmt.Sprintf("'%s'@'%s'", userEsc, hostEsc)
115150

116-
query := fmt.Sprintf("REVOKE %s ON %s %s.%s FROM '%s'",
117-
privilege, strings.ToUpper(routineType), schema, routineName, userRevoke)
118-
_, err = c.db.ExecContext(ctx, query)
119-
return err
151+
query := fmt.Sprintf("REVOKE %s ON %s %s.%s FROM %s",
152+
privilege, strings.ToUpper(routineType), schemaEsc, routineNameEsc, userRevoke)
153+
_ = c.db.MustExec(query)
154+
return nil
120155
}
121156

122157
func (c *Client) GetRoutineType(ctx context.Context, schema, routineName string) (string, error) {

pkg/client/servers.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"fmt"
7+
"strings"
78

89
"github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap"
910
)
@@ -37,3 +38,43 @@ func (c *Client) ExecContext(ctx context.Context, query string) (sql.Result, err
3738
}
3839
return c.db.ExecContext(ctx, query)
3940
}
41+
42+
func (c *Client) GrantServerPrivilege(ctx context.Context, user string, privilege string) error {
43+
userSplit := strings.Split(user, "@")
44+
if len(userSplit) != 2 {
45+
return fmt.Errorf("invalid user format, expected user@host")
46+
}
47+
userEsc, err := escapeMySQLUserHost(userSplit[0])
48+
if err != nil {
49+
return err
50+
}
51+
hostEsc, err := escapeMySQLUserHost(userSplit[1])
52+
if err != nil {
53+
return err
54+
}
55+
userGrant := fmt.Sprintf("'%s'@'%s'", userEsc, hostEsc)
56+
57+
query := fmt.Sprintf("GRANT %s ON *.* TO %s", strings.ToUpper(privilege), userGrant)
58+
_ = c.db.MustExec(query)
59+
return nil
60+
}
61+
62+
func (c *Client) RevokeServerPrivilege(ctx context.Context, user string, privilege string) error {
63+
userSplit := strings.Split(user, "@")
64+
if len(userSplit) != 2 {
65+
return fmt.Errorf("invalid user format, expected user@host")
66+
}
67+
userEsc, err := escapeMySQLUserHost(userSplit[0])
68+
if err != nil {
69+
return err
70+
}
71+
hostEsc, err := escapeMySQLUserHost(userSplit[1])
72+
if err != nil {
73+
return err
74+
}
75+
userRevoke := fmt.Sprintf("'%s'@'%s'", userEsc, hostEsc)
76+
77+
query := fmt.Sprintf("REVOKE %s ON *.* FROM %s", strings.ToUpper(privilege), userRevoke)
78+
_ = c.db.MustExec(query)
79+
return nil
80+
}

0 commit comments

Comments
 (0)