@@ -89,34 +89,69 @@ func (c *Client) GrantRoutinePrivilege(ctx context.Context, privilege string, sc
8989 return err
9090 }
9191
92+ schemaEsc , err := escapeMySQLIdent (schema )
93+ if err != nil {
94+ return err
95+ }
96+ routineNameEsc , err := escapeMySQLIdent (routineName )
97+ if err != nil {
98+ return err
99+ }
100+
92101 userSplit := strings .Split (user , "@" )
93102 if len (userSplit ) != 2 {
94103 return fmt .Errorf ("invalid user format, expected user@host" )
95104 }
96- userGrant := fmt .Sprintf ("%s'@'%s" , userSplit [0 ], userSplit [1 ])
105+ userEsc , err := escapeMySQLUserHost (userSplit [0 ])
106+ if err != nil {
107+ return err
108+ }
109+ hostEsc , err := escapeMySQLUserHost (userSplit [1 ])
110+ if err != nil {
111+ return err
112+ }
113+ userGrant := fmt .Sprintf ("'%s'@'%s'" , userEsc , hostEsc )
97114
98- query := fmt .Sprintf ("GRANT %s ON %s %s.%s TO '%s' " ,
99- privilege , strings .ToUpper (routineType ), schema , routineName , userGrant )
115+ query := fmt .Sprintf ("GRANT %s ON %s %s.%s TO %s " ,
116+ privilege , strings .ToUpper (routineType ), schemaEsc , routineNameEsc , userGrant )
100117
101- _ , err = c .db .ExecContext ( ctx , query )
102- return err
118+ _ = c .db .MustExec ( query )
119+ return nil
103120}
121+
104122func (c * Client ) RevokeRoutinePrivilege (ctx context.Context , privilege string , schema string , routineName string , user string ) error {
105123 routineType , err := c .GetRoutineType (ctx , schema , routineName )
106124 if err != nil {
107125 return err
108126 }
109127
128+ schemaEsc , err := escapeMySQLIdent (schema )
129+ if err != nil {
130+ return err
131+ }
132+ routineNameEsc , err := escapeMySQLIdent (routineName )
133+ if err != nil {
134+ return err
135+ }
136+
110137 userSplit := strings .Split (user , "@" )
111138 if len (userSplit ) != 2 {
112139 return fmt .Errorf ("invalid user format, expected user@host" )
113140 }
114- userRevoke := fmt .Sprintf ("%s'@'%s" , userSplit [0 ], userSplit [1 ])
141+ userEsc , err := escapeMySQLUserHost (userSplit [0 ])
142+ if err != nil {
143+ return err
144+ }
145+ hostEsc , err := escapeMySQLUserHost (userSplit [1 ])
146+ if err != nil {
147+ return err
148+ }
149+ userRevoke := fmt .Sprintf ("'%s'@'%s'" , userEsc , hostEsc )
115150
116- query := fmt .Sprintf ("REVOKE %s ON %s %s.%s FROM '%s' " ,
117- privilege , strings .ToUpper (routineType ), schema , routineName , userRevoke )
118- _ , err = c .db .ExecContext ( ctx , query )
119- return err
151+ query := fmt .Sprintf ("REVOKE %s ON %s %s.%s FROM %s " ,
152+ privilege , strings .ToUpper (routineType ), schemaEsc , routineNameEsc , userRevoke )
153+ _ = c .db .MustExec ( query )
154+ return nil
120155}
121156
122157func (c * Client ) GetRoutineType (ctx context.Context , schema , routineName string ) (string , error ) {
0 commit comments