diff --git a/mysql/provider.go b/mysql/provider.go index c7a7c9bf..d3910c1a 100644 --- a/mysql/provider.go +++ b/mysql/provider.go @@ -45,14 +45,15 @@ import ( ) const ( - cleartextPasswords = "cleartext" - nativePasswords = "native" - userNotFoundErrCode = 1133 - unknownUserErrCode = 1396 - azEnvPublic = "public" - azEnvChina = "china" - azEnvGerman = "german" - azEnvUSGovernment = "usgovernment" + cleartextPasswords = "cleartext" + nativePasswords = "native" + userNotFoundErrCode = 1133 + unknownUserErrCode = 1396 + nonExistingGrantErrCode = 1141 + azEnvPublic = "public" + azEnvChina = "china" + azEnvGerman = "german" + azEnvUSGovernment = "usgovernment" ) type OneConnection struct { @@ -890,6 +891,13 @@ func quoteIdentifier(in string) string { return fmt.Sprintf("`%s`", identQuoteReplacer.Replace(in)) } +// quoteRoleName safely quotes role names with backticks and proper escaping. +// It escapes backticks by doubling them (e.g., `name“with`backtick“ becomes `name“with“backtick“). +// Backtick quoting is preferred over single quotes because backslashes don't need escaping. +func quoteRoleName(s string) string { + return fmt.Sprintf("`%s`", strings.ReplaceAll(s, "`", "``")) +} + func serverVersion(db *sql.DB) (*version.Version, error) { var versionString string err := db.QueryRow("SELECT @@GLOBAL.version").Scan(&versionString) diff --git a/mysql/resource_default_roles.go b/mysql/resource_default_roles.go index 19e5e65c..2daf47d4 100644 --- a/mysql/resource_default_roles.go +++ b/mysql/resource_default_roles.go @@ -60,10 +60,15 @@ func checkDefaultRolesSupport(ctx context.Context, meta interface{}) error { func alterUserDefaultRoles(ctx context.Context, db *sql.DB, user, host string, roles []string) error { var stmtSQL string - stmtSQL = fmt.Sprintf("ALTER USER '%s'@'%s' DEFAULT ROLE ", user, host) + // Use formatUserIdentifier for consistent quoting (backtick + escaping) + stmtSQL = fmt.Sprintf("ALTER USER %s DEFAULT ROLE ", formatUserIdentifier(user, host)) if len(roles) > 0 { - stmtSQL += fmt.Sprintf("'%s'", strings.Join(roles, "', '")) + quotedRoles := make([]string, len(roles)) + for i, role := range roles { + quotedRoles[i] = quoteRoleName(role) + } + stmtSQL += fmt.Sprintf("%s", strings.Join(quotedRoles, ", ")) } else { stmtSQL += "NONE" } @@ -191,14 +196,13 @@ func DeleteDefaultRoles(ctx context.Context, d *schema.ResourceData, meta interf } func ImportDefaultRoles(ctx context.Context, d *schema.ResourceData, meta interface{}) ([]*schema.ResourceData, error) { - userHost := strings.SplitN(d.Id(), "@", 2) - - if len(userHost) != 2 { - return nil, fmt.Errorf("wrong ID format %s (expected USER@HOST)", d.Id()) + user, host, err := parseUserHost(d.Id()) + if err != nil { + return nil, err } - d.Set("user", userHost[0]) - d.Set("host", userHost[1]) + d.Set("user", user) + d.Set("host", host) readDiags := ReadDefaultRoles(ctx, d, meta) for _, readDiag := range readDiags { diff --git a/mysql/resource_grant.go b/mysql/resource_grant.go index a008ad58..e2428272 100644 --- a/mysql/resource_grant.go +++ b/mysql/resource_grant.go @@ -70,10 +70,15 @@ func (u UserOrRole) IDString() string { } func (u UserOrRole) SQLString() string { + // If Host is empty, it's a role - use backticks with doubled-backtick escaping + // If Host is not empty, it's a user - use single quotes with doubled-quote escaping if u.Host == "" { - return fmt.Sprintf("'%s'", u.Name) + escapedName := strings.ReplaceAll(u.Name, "`", "``") + return fmt.Sprintf("`%s`", escapedName) } - return fmt.Sprintf("'%s'@'%s'", u.Name, u.Host) + escapedName := strings.ReplaceAll(u.Name, "'", "''") + escapedHost := strings.ReplaceAll(u.Host, "'", "''") + return fmt.Sprintf("'%s'@'%s'", escapedName, escapedHost) } func (u UserOrRole) Equals(other UserOrRole) bool { @@ -278,7 +283,11 @@ func (t *RoleGrant) GrantOption() bool { } func (t *RoleGrant) SQLGrantStatement() string { - stmtSql := fmt.Sprintf("GRANT '%s' TO %s", strings.Join(t.Roles, "', '"), t.UserOrRole.SQLString()) + quotedRoles := make([]string, len(t.Roles)) + for i, role := range t.Roles { + quotedRoles[i] = quoteRoleName(role) + } + stmtSql := fmt.Sprintf("GRANT %s TO %s", strings.Join(quotedRoles, ", "), t.UserOrRole.SQLString()) if t.TLSOption != "" && strings.ToLower(t.TLSOption) != "none" { stmtSql += fmt.Sprintf(" REQUIRE %s", t.TLSOption) } @@ -289,7 +298,11 @@ func (t *RoleGrant) SQLGrantStatement() string { } func (t *RoleGrant) SQLRevokeStatement() string { - return fmt.Sprintf("REVOKE '%s' FROM %s", strings.Join(t.Roles, "', '"), t.UserOrRole.SQLString()) + quotedRoles := make([]string, len(t.Roles)) + for i, role := range t.Roles { + quotedRoles[i] = quoteRoleName(role) + } + return fmt.Sprintf("REVOKE %s FROM %s", strings.Join(quotedRoles, ", "), t.UserOrRole.SQLString()) } func (t *RoleGrant) GetRoles() []string { @@ -638,7 +651,7 @@ func DeleteGrant(ctx context.Context, d *schema.ResourceData, meta interface{}) // Parse the grant from ResourceData grant, diagErr := parseResourceFromData(d) - if err != nil { + if diagErr != nil { return diagErr } @@ -667,24 +680,63 @@ func isNonExistingGrant(err error) bool { } func ImportGrant(ctx context.Context, d *schema.ResourceData, meta interface{}) ([]*schema.ResourceData, error) { - userHostDatabaseTable := strings.Split(strings.TrimSuffix(d.Id(), ";r"), "@") + idWithoutSuffix := strings.TrimSuffix(d.Id(), ";r") + userHostDatabaseTable := strings.Split(idWithoutSuffix, "@") + + // Expected formats: + // - user@host@database@table (4 parts) - no grant option + // - user@host@database@table@ (5 parts with empty last element) - with grant option + // - user@host@database@table;r (4 parts + ;r suffix) - role grant without grant option + // - user@host@database@table@;r (5 parts with empty last element + ;r suffix) - role grant with grant option + // If the username contains @ (e.g., user@domain.com), there will be more parts than expected. + // Join the extra parts at the beginning to reconstruct the username. + // For example: user@domain.com@host@database@table -> parts = ["user", "domain.com", "host", "database", "table"] + + isRoleGrant := strings.HasSuffix(d.Id(), ";r") + + // Check if the ID ends with @ (before ;r), which indicates grant option + // The trailing @ creates an empty string element when splitting + hasTrailingAt := len(userHostDatabaseTable) > 0 && userHostDatabaseTable[len(userHostDatabaseTable)-1] == "" + grantOption := hasTrailingAt + + // Expected number of parts for standard format (without considering embedded @ in username) + // If grant option is present, we have 5 elements (with empty last one) + baseExpectedParts := 4 // user@host@database@table (without grant option) + + var user, host, database, table string + + if len(userHostDatabaseTable) > baseExpectedParts+1 { + // Username contains @ - need to reconstruct it + // The extra parts beyond baseExpectedParts belong to the username + // If grant option is present, there's an extra empty element at the end, so we subtract 1 + extraParts := len(userHostDatabaseTable) - baseExpectedParts + if hasTrailingAt { + extraParts-- + } - if len(userHostDatabaseTable) != 4 && len(userHostDatabaseTable) != 5 { + // The first extraParts+1 elements form the username (all but the last 3 parts which are host, database, table) + numUserParts := extraParts + 1 + user = strings.Join(userHostDatabaseTable[:numUserParts], "@") + host = userHostDatabaseTable[numUserParts] + database = userHostDatabaseTable[numUserParts+1] + table = userHostDatabaseTable[numUserParts+2] + } else if len(userHostDatabaseTable) == baseExpectedParts || len(userHostDatabaseTable) == baseExpectedParts+1 { + // Standard case - no embedded @ in username + user = userHostDatabaseTable[0] + host = userHostDatabaseTable[1] + database = userHostDatabaseTable[2] + table = userHostDatabaseTable[3] + } else { return nil, fmt.Errorf("wrong ID format %s - expected user@host@database@table (and optionally ending @ to signify grant option) where some parts can be empty)", d.Id()) } - user := userHostDatabaseTable[0] - host := userHostDatabaseTable[1] - database := userHostDatabaseTable[2] - table := userHostDatabaseTable[3] - grantOption := len(userHostDatabaseTable) == 5 userOrRole := UserOrRole{ Name: user, Host: host, } var desiredGrant MySQLGrant - if strings.HasSuffix(d.Id(), ";r") { + if isRoleGrant { desiredGrant = &RoleGrant{ UserOrRole: userOrRole, Grant: grantOption, @@ -724,7 +776,7 @@ func ImportGrant(ctx context.Context, d *schema.ResourceData, meta interface{}) } } - return nil, fmt.Errorf("failed to find the grant to import: %v -- found %#v", userHostDatabaseTable, grants) + return nil, fmt.Errorf("failed to find the grant to import: user=%s host=%s database=%s table=%s -- found %#v", user, host, database, table, grants) } // setDataFromGrant copies the values from MySQLGrant to the schema.ResourceData @@ -768,12 +820,19 @@ func setDataFromGrant(grant MySQLGrant, d *schema.ResourceData) *schema.Resource d.Set("database", tablePrivGrant.Database) } - // This is a bit of a hack, since we don't have a way to distingush between users and roles - // from the grant itself. We can only infer it from the schema. userOrRole := grant.GetUserOrRole() - if d.Get("role") != "" { + if _, ok := grant.(*RoleGrant); ok { + // This is a role grant - set the role attribute + d.Set("role", userOrRole.Name) + d.Set("user", "") + d.Set("host", "") + } else if d.Get("role") != "" { + // Role was specified in config d.Set("role", userOrRole.Name) + d.Set("user", "") + d.Set("host", "") } else { + // User grant d.Set("user", userOrRole.Name) d.Set("host", userOrRole.Host) } @@ -840,24 +899,44 @@ func getMatchingGrant(ctx context.Context, db *sql.DB, desiredGrant MySQLGrant) } var ( - kUserOrRoleRegex = regexp.MustCompile("['`]?([^'`]+)['`]?(?:@['`]?([^'`]+)['`]?)?") + // kUserOrRoleRegex matches user/role names with proper handling of backslash escape sequences + // and doubled single quotes (SQL standard escaping for single quotes in identifiers). + // Pattern handles: unquoted names, single-quoted names, and backtick-quoted names. + // For quoted names, it properly captures backslash-escaped characters (e.g., \' or \\) and doubled single quotes (''). + // Importantly, @ is allowed inside quoted usernames to support GCP IAM email addresses like 'user@example.com'@'%'. + // Group 1: username (quoted or unquoted) + // Group 2: host (quoted or unquoted, optional) + kUserOrRoleRegex = regexp.MustCompile("^((?:'(?:[^'\\\\]|\\\\.|'')*'|`(?:[^`\\\\]|\\\\.|``)*`|(?:[^'\"`@\\\\]|\\\\.)+))(?:@((?:'(?:[^'\\\\]|\\\\.|'')*'|`(?:[^`\\\\]|\\\\.|``)*`|(?:[^'\"`\\\\]|\\\\.)+)))?$") ) +// stripQuotes removes outer matching quotes (single quotes or backticks) from a string +func stripQuotes(s string) string { + if len(s) >= 2 { + if (s[0] == '\'' && s[len(s)-1] == '\'') || (s[0] == '`' && s[len(s)-1] == '`') { + return s[1 : len(s)-1] + } + } + return s +} + func parseUserOrRoleFromRow(userOrRoleStr string) (*UserOrRole, error) { userHostMatches := kUserOrRoleRegex.FindStringSubmatch(userOrRoleStr) - if len(userHostMatches) == 3 { - return &UserOrRole{ - Name: userHostMatches[1], - Host: userHostMatches[2], - }, nil - } else if len(userHostMatches) == 2 { + // Group structure with the new regex: + // [0] full match, [1] username (may include quotes), [2] host (may include quotes, optional) + if len(userHostMatches) >= 2 && userHostMatches[1] != "" { + // Strip outer quotes and unescape + name := unescapeRoleName(stripQuotes(userHostMatches[1])) + host := "%" + // Has host (group 2) + if len(userHostMatches) >= 3 && userHostMatches[2] != "" { + host = unescapeRoleName(stripQuotes(userHostMatches[2])) + } return &UserOrRole{ - Name: userHostMatches[1], - Host: "%", + Name: name, + Host: host, }, nil - } else { - return nil, fmt.Errorf("failed to parse user or role portion of grant statement: %s", userOrRoleStr) } + return nil, fmt.Errorf("failed to parse user or role portion of grant statement: %s", userOrRoleStr) } var ( @@ -961,7 +1040,13 @@ func parseGrantFromRow(grantStr string) (MySQLGrant, error) { roles := make([]string, len(rolesStart)) for i, role := range rolesStart { - roles[i] = strings.Trim(role, "`@%\" ") + role = strings.TrimSpace(role) + // Remove outer quotes if present + if len(role) >= 2 && ((role[0] == '`' && role[len(role)-1] == '`') || + (role[0] == '\'' && role[len(role)-1] == '\'')) { + role = role[1 : len(role)-1] + } + roles[i] = unescapeRoleName(role) } userOrRole, err := parseUserOrRoleFromRow(roleMatches[2]) @@ -983,9 +1068,42 @@ func parseGrantFromRow(grantStr string) (MySQLGrant, error) { } } +// unescapeRoleName reverses the escaping done by quoteRoleName in provider.go. +// It handles backslash-escaping (from Terraform import double-escaping), doubled backticks +// (for backward compatibility with backtick-quoted identifiers), and doubled single quotes +// (for single-quoted identifiers used in GRANT/CREATE ROLE statements). +func unescapeRoleName(s string) string { + // Unescape doubled backslashes first (handles Terraform import double-escaping) + s = strings.ReplaceAll(s, "\\\\", "\\") + // Handle doubled backticks (for backtick-quoted identifiers) + s = strings.ReplaceAll(s, "``", "`") + // Handle doubled single quotes (for backward compatibility with single-quoted strings) + s = strings.ReplaceAll(s, "''", "'") + return s +} + func showUserGrants(ctx context.Context, db *sql.DB, userOrRole UserOrRole) ([]MySQLGrant, error) { grants := []MySQLGrant{} + // Check if this is a cloudiamgroup user on GCP CloudSQL (version ends with "-google") + // GCP CloudSQL uses cloudiamgroup as a placeholder role, but SHOW GRANTS returns an error + // We only fetch the server version if needed (when user is cloudiamgroup) + isGCPCloudSQL := false + if userOrRole.Name == "cloudiamgroup" { + serverVersion, err := serverVersionString(db) + if err != nil { + log.Printf("[WARN] Failed to get server version for cloudiamgroup check: %v", err) + } else { + isGCPCloudSQL = strings.HasSuffix(serverVersion, "-google") + } + } + + // On GCP CloudSQL, cloudiamgroup grants are not real and should be ignored + if isGCPCloudSQL { + log.Printf("[DEBUG] Detected GCP CloudSQL (version ends with -google), skipping cloudiamgroup grants") + return grants, nil + } + sqlStatement := fmt.Sprintf("SHOW GRANTS FOR %s", userOrRole.SQLString()) log.Printf("[DEBUG] SQL to show grants: %s", sqlStatement) rows, err := db.QueryContext(ctx, sqlStatement) diff --git a/mysql/resource_role.go b/mysql/resource_role.go index 5a97e955..bd73a87a 100644 --- a/mysql/resource_role.go +++ b/mysql/resource_role.go @@ -14,6 +14,9 @@ func resourceRole() *schema.Resource { CreateContext: CreateRole, ReadContext: ReadRole, DeleteContext: DeleteRole, + Importer: &schema.ResourceImporter{ + StateContext: schema.ImportStatePassthroughContext, + }, Schema: map[string]*schema.Schema{ "name": { @@ -32,8 +35,9 @@ func CreateRole(ctx context.Context, d *schema.ResourceData, meta interface{}) d } roleName := d.Get("name").(string) + log.Printf("[DEBUG] CreateRole: roleName=%q, escaped=%q", roleName, quoteRoleName(roleName)) - sql := fmt.Sprintf("CREATE ROLE '%s'", roleName) + sql := fmt.Sprintf("CREATE ROLE %s", quoteRoleName(roleName)) log.Printf("[DEBUG] SQL: %s", sql) _, err = db.ExecContext(ctx, sql) @@ -52,17 +56,29 @@ func ReadRole(ctx context.Context, d *schema.ResourceData, meta interface{}) dia return diag.FromErr(err) } - sql := fmt.Sprintf("SHOW GRANTS FOR '%s'", d.Id()) + roleName := unescapeRoleName(d.Id()) + log.Printf("[DEBUG] ReadRole: d.Id()=%q, unescaped=%q, escaped=%q", d.Id(), roleName, quoteRoleName(roleName)) + sql := fmt.Sprintf("SHOW GRANTS FOR %s", quoteRoleName(roleName)) log.Printf("[DEBUG] SQL: %s", sql) - _, err = db.ExecContext(ctx, sql) + rows, err := db.QueryContext(ctx, sql) if err != nil { - log.Printf("[WARN] Role (%s) not found; removing from state", d.Id()) + errorNumber := mysqlErrorNumber(err) + if errorNumber == unknownUserErrCode || errorNumber == userNotFoundErrCode || errorNumber == nonExistingGrantErrCode { + d.SetId("") + return nil + } + return diag.Errorf("error reading role: %s", err) + } + defer rows.Close() + + if !rows.Next() { d.SetId("") return nil } - d.Set("name", d.Id()) + d.SetId(roleName) + d.Set("name", roleName) return nil } @@ -73,7 +89,7 @@ func DeleteRole(ctx context.Context, d *schema.ResourceData, meta interface{}) d return diag.FromErr(err) } - sql := fmt.Sprintf("DROP ROLE '%s'", d.Get("name").(string)) + sql := fmt.Sprintf("DROP ROLE %s", quoteRoleName(d.Get("name").(string))) log.Printf("[DEBUG] SQL: %s", sql) _, err = db.ExecContext(ctx, sql) diff --git a/mysql/resource_role_test.go b/mysql/resource_role_test.go index 631d1690..b7f9e6f1 100644 --- a/mysql/resource_role_test.go +++ b/mysql/resource_role_test.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "fmt" + "regexp" + "strings" "testing" "github.com/hashicorp/go-version" @@ -45,6 +47,12 @@ func TestAccRole_basic(t *testing.T) { resource.TestCheckResourceAttr(resourceName, "name", roleName), ), }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName, + }, }, }) } @@ -72,7 +80,7 @@ func testAccRoleExists(roleName string) resource.TestCheckFunc { } func testAccGetRoleGrantCount(roleName string, db *sql.DB) (int, error) { - rows, err := db.Query(fmt.Sprintf("SHOW GRANTS FOR '%s'", roleName)) + rows, err := db.Query(fmt.Sprintf("SHOW GRANTS FOR %s", quoteRoleName(roleName))) if err != nil { return 0, err } @@ -105,9 +113,762 @@ func testAccRoleCheckDestroy(roleName string) resource.TestCheckFunc { } func testAccRoleConfigBasic(roleName string) string { + // Escape backslashes first, then double quotes for HCL string literal + escaped := strings.ReplaceAll(roleName, `\`, `\\`) + escaped = strings.ReplaceAll(escaped, `"`, `\"`) return fmt.Sprintf(` resource "mysql_role" "test" { name = "%s" } +`, escaped) +} + +func testAccRoleConfigDifferent(roleName string) string { + return fmt.Sprintf(` +resource "mysql_role" "different" { + name = "%s" +} `, roleName) } + +func testAccRoleConfigMultiple(roleName1, roleName2 string) string { + return fmt.Sprintf(` +resource "mysql_role" "test1" { + name = "%s" +} + +resource "mysql_role" "test2" { + name = "%s" +} +`, roleName1, roleName2) +} + +func TestAccRole_importBasic(t *testing.T) { + roleName := "tf-test-role-import-basic" + resourceName := "mysql_role.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { + testAccPreCheck(t) + testAccPreCheckSkipRds(t) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) + if err != nil { + return + } + + requiredVersion, _ := version.NewVersion("8.0.0") + currentVersion, err := serverVersion(db) + if err != nil { + return + } + + if currentVersion.LessThan(requiredVersion) { + t.Skip("Roles require MySQL 8+") + } + }, + ProviderFactories: testAccProviderFactories, + CheckDestroy: testAccRoleCheckDestroy(roleName), + Steps: []resource.TestStep{ + { + Config: testAccRoleConfigBasic(roleName), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName, + }, + }, + }) +} + +func TestAccRole_importSpecialChars(t *testing.T) { + roleName := "tf-test-role@#$%" + resourceName := "mysql_role.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { + testAccPreCheck(t) + testAccPreCheckSkipRds(t) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) + if err != nil { + return + } + + requiredVersion, _ := version.NewVersion("8.0.0") + currentVersion, err := serverVersion(db) + if err != nil { + return + } + + if currentVersion.LessThan(requiredVersion) { + t.Skip("Roles require MySQL 8+") + } + }, + ProviderFactories: testAccProviderFactories, + CheckDestroy: testAccRoleCheckDestroy(roleName), + Steps: []resource.TestStep{ + { + Config: testAccRoleConfigBasic(roleName), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName, + }, + }, + }) +} + +func TestAccRole_importAndPlan(t *testing.T) { + roleName := "tf-test-role-import-plan" + resourceName := "mysql_role.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { + testAccPreCheck(t) + testAccPreCheckSkipRds(t) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) + if err != nil { + return + } + + requiredVersion, _ := version.NewVersion("8.0.0") + currentVersion, err := serverVersion(db) + if err != nil { + return + } + + if currentVersion.LessThan(requiredVersion) { + t.Skip("Roles require MySQL 8+") + } + }, + ProviderFactories: testAccProviderFactories, + CheckDestroy: testAccRoleCheckDestroy(roleName), + Steps: []resource.TestStep{ + { + Config: testAccRoleConfigBasic(roleName), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName, + }, + { + Config: testAccRoleConfigBasic(roleName), + PlanOnly: true, + ExpectNonEmptyPlan: false, + }, + }, + }) +} + +func TestAccRole_importAndDestroy(t *testing.T) { + roleName := "tf-test-role-import-destroy" + resourceName := "mysql_role.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { + testAccPreCheck(t) + testAccPreCheckSkipRds(t) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) + if err != nil { + return + } + + requiredVersion, _ := version.NewVersion("8.0.0") + currentVersion, err := serverVersion(db) + if err != nil { + return + } + + if currentVersion.LessThan(requiredVersion) { + t.Skip("Roles require MySQL 8+") + } + }, + ProviderFactories: testAccProviderFactories, + CheckDestroy: testAccRoleCheckDestroy(roleName), + Steps: []resource.TestStep{ + { + Config: testAccRoleConfigBasic(roleName), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName, + }, + }, + }) +} + +func TestAccRole_importWithDifferentResourceName(t *testing.T) { + roleName := "tf-test-role-diff-resource" + resourceName := "mysql_role.different" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { + testAccPreCheck(t) + testAccPreCheckSkipRds(t) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) + if err != nil { + return + } + + requiredVersion, _ := version.NewVersion("8.0.0") + currentVersion, err := serverVersion(db) + if err != nil { + return + } + + if currentVersion.LessThan(requiredVersion) { + t.Skip("Roles require MySQL 8+") + } + }, + ProviderFactories: testAccProviderFactories, + CheckDestroy: testAccRoleCheckDestroy(roleName), + Steps: []resource.TestStep{ + { + Config: testAccRoleConfigDifferent(roleName), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName, + }, + }, + }) +} + +func TestAccRole_importMultipleRoles(t *testing.T) { + roleName1 := "tf-test-role-multi1" + roleName2 := "tf-test-role-multi2" + resourceName1 := "mysql_role.test1" + resourceName2 := "mysql_role.test2" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { + testAccPreCheck(t) + testAccPreCheckSkipRds(t) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) + if err != nil { + return + } + + requiredVersion, _ := version.NewVersion("8.0.0") + currentVersion, err := serverVersion(db) + if err != nil { + return + } + + if currentVersion.LessThan(requiredVersion) { + t.Skip("Roles require MySQL 8+") + } + }, + ProviderFactories: testAccProviderFactories, + CheckDestroy: resource.ComposeTestCheckFunc( + testAccRoleCheckDestroy(roleName1), + testAccRoleCheckDestroy(roleName2), + ), + Steps: []resource.TestStep{ + { + Config: testAccRoleConfigMultiple(roleName1, roleName2), + }, + { + ResourceName: resourceName1, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName1, + }, + { + ResourceName: resourceName2, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName2, + }, + }, + }) +} + +func TestAccRole_importNonExistent(t *testing.T) { + roleName := "tf-test-role-nonexistent" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { + testAccPreCheck(t) + testAccPreCheckSkipRds(t) + testAccPreCheckSkipMariaDB(t) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) + if err != nil { + return + } + + requiredVersion, _ := version.NewVersion("8.0.0") + currentVersion, err := serverVersion(db) + if err != nil { + return + } + + if currentVersion.LessThan(requiredVersion) { + t.Skip("Roles require MySQL 8+") + } + }, + ProviderFactories: testAccProviderFactories, + Steps: []resource.TestStep{ + { + ResourceName: "mysql_role.test", + ImportState: true, + ImportStateId: roleName, + ExpectError: regexp.MustCompile("Cannot import non-existent remote object"), + }, + }, + }) +} +func TestAccRole_importWithGrants(t *testing.T) { + roleName := "tf-test-role-with-grants" + resourceName := "mysql_role.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { + testAccPreCheck(t) + testAccPreCheckSkipRds(t) + testAccPreCheckSkipMariaDB(t) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) + if err != nil { + return + } + + requiredVersion, _ := version.NewVersion("8.0.0") + currentVersion, err := serverVersion(db) + if err != nil { + return + } + + if currentVersion.LessThan(requiredVersion) { + t.Skip("Roles require MySQL 8+") + } + }, + ProviderFactories: testAccProviderFactories, + CheckDestroy: testAccRoleCheckDestroy(roleName), + Steps: []resource.TestStep{ + { + Config: testAccRoleConfigWithGrants(roleName), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName, + }, + }, + }) +} + +func TestAccRole_importAndUpdate(t *testing.T) { + roleName1 := "tf-test-role-update1" + roleName2 := "tf-test-role-update2" + resourceName := "mysql_role.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { + testAccPreCheck(t) + testAccPreCheckSkipRds(t) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) + if err != nil { + return + } + + requiredVersion, _ := version.NewVersion("8.0.0") + currentVersion, err := serverVersion(db) + if err != nil { + return + } + + if currentVersion.LessThan(requiredVersion) { + t.Skip("Roles require MySQL 8+") + } + }, + ProviderFactories: testAccProviderFactories, + CheckDestroy: resource.ComposeTestCheckFunc( + testAccRoleCheckDestroy(roleName1), + testAccRoleCheckDestroy(roleName2), + ), + Steps: []resource.TestStep{ + { + Config: testAccRoleConfigBasic(roleName1), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName1, + }, + { + Config: testAccRoleConfigBasic(roleName2), + Check: resource.ComposeTestCheckFunc( + testAccRoleExists(roleName2), + resource.TestCheckResourceAttr(resourceName, "name", roleName2), + ), + }, + }, + }) +} + +func TestAccRole_importConcurrency(t *testing.T) { + roleName1 := "tf-test-role-conc1" + roleName2 := "tf-test-role-conc2" + resourceName1 := "mysql_role.test1" + resourceName2 := "mysql_role.test2" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { + testAccPreCheck(t) + testAccPreCheckSkipRds(t) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) + if err != nil { + return + } + + requiredVersion, _ := version.NewVersion("8.0.0") + currentVersion, err := serverVersion(db) + if err != nil { + return + } + + if currentVersion.LessThan(requiredVersion) { + t.Skip("Roles require MySQL 8+") + } + }, + ProviderFactories: testAccProviderFactories, + CheckDestroy: resource.ComposeTestCheckFunc( + testAccRoleCheckDestroy(roleName1), + testAccRoleCheckDestroy(roleName2), + ), + Steps: []resource.TestStep{ + { + Config: testAccRoleConfigMultiple(roleName1, roleName2), + }, + { + ResourceName: resourceName1, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName1, + }, + { + ResourceName: resourceName2, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName2, + }, + }, + }) +} + +func TestAccRole_importSingleQuote(t *testing.T) { + roleName := "tf-test-role'quote" + resourceName := "mysql_role.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { + testAccPreCheck(t) + testAccPreCheckSkipRds(t) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) + if err != nil { + return + } + + requiredVersion, _ := version.NewVersion("8.0.0") + currentVersion, err := serverVersion(db) + if err != nil { + return + } + + if currentVersion.LessThan(requiredVersion) { + t.Skip("Roles require MySQL 8+") + } + }, + ProviderFactories: testAccProviderFactories, + CheckDestroy: testAccRoleCheckDestroy(roleName), + Steps: []resource.TestStep{ + { + Config: testAccRoleConfigBasic(roleName), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName, + }, + }, + }) +} + +func TestAccRole_importBackslash(t *testing.T) { + roleName := "tf-test-role\\\\backslash" + resourceName := "mysql_role.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { + testAccPreCheck(t) + testAccPreCheckSkipRds(t) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) + if err != nil { + return + } + + requiredVersion, _ := version.NewVersion("8.0.0") + currentVersion, err := serverVersion(db) + if err != nil { + return + } + + if currentVersion.LessThan(requiredVersion) { + t.Skip("Roles require MySQL 8+") + } + }, + ProviderFactories: testAccProviderFactories, + CheckDestroy: testAccRoleCheckDestroy(roleName), + Steps: []resource.TestStep{ + { + Config: testAccRoleConfigBasic(roleName), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName, + }, + }, + }) +} + +func TestAccRole_importDoubleQuote(t *testing.T) { + roleName := "tf-test-role\\\"quote-test" + resourceName := "mysql_role.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { + testAccPreCheck(t) + testAccPreCheckSkipRds(t) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) + if err != nil { + return + } + + requiredVersion, _ := version.NewVersion("8.0.0") + currentVersion, err := serverVersion(db) + if err != nil { + return + } + + if currentVersion.LessThan(requiredVersion) { + t.Skip("Roles require MySQL 8+") + } + }, + ProviderFactories: testAccProviderFactories, + CheckDestroy: testAccRoleCheckDestroy(roleName), + Steps: []resource.TestStep{ + { + Config: testAccRoleConfigBasic(roleName), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName, + }, + }, + }) +} + +func TestAccRole_importUnicode(t *testing.T) { + roleName := "tf-test-role-unicode-测试" + resourceName := "mysql_role.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { + testAccPreCheck(t) + testAccPreCheckSkipRds(t) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) + if err != nil { + return + } + + requiredVersion, _ := version.NewVersion("8.0.0") + currentVersion, err := serverVersion(db) + if err != nil { + return + } + + if currentVersion.LessThan(requiredVersion) { + t.Skip("Roles require MySQL 8+") + } + }, + ProviderFactories: testAccProviderFactories, + CheckDestroy: testAccRoleCheckDestroy(roleName), + Steps: []resource.TestStep{ + { + Config: testAccRoleConfigBasic(roleName), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName, + }, + }, + }) +} + +func TestAccRole_importReservedWord(t *testing.T) { + roleName := "SELECT" + resourceName := "mysql_role.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { + testAccPreCheck(t) + testAccPreCheckSkipRds(t) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) + if err != nil { + return + } + + requiredVersion, _ := version.NewVersion("8.0.0") + currentVersion, err := serverVersion(db) + if err != nil { + return + } + + if currentVersion.LessThan(requiredVersion) { + t.Skip("Roles require MySQL 8+") + } + }, + ProviderFactories: testAccProviderFactories, + CheckDestroy: testAccRoleCheckDestroy(roleName), + Steps: []resource.TestStep{ + { + Config: testAccRoleConfigBasic(roleName), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName, + }, + }, + }) +} + +func TestAccRole_importLongName(t *testing.T) { + roleName := "tf-test-role-long-" + strings.Repeat("a", 14) + resourceName := "mysql_role.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { + testAccPreCheck(t) + testAccPreCheckSkipRds(t) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) + if err != nil { + return + } + + requiredVersion, _ := version.NewVersion("8.0.0") + currentVersion, err := serverVersion(db) + if err != nil { + return + } + + if currentVersion.LessThan(requiredVersion) { + t.Skip("Roles require MySQL 8+") + } + }, + ProviderFactories: testAccProviderFactories, + CheckDestroy: testAccRoleCheckDestroy(roleName), + Steps: []resource.TestStep{ + { + Config: testAccRoleConfigBasic(roleName), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName, + }, + }, + }) +} + +func TestAccRole_importMultipleSpecialChars(t *testing.T) { + roleName := "tf-test-role@#$%^&*()" + resourceName := "mysql_role.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { + testAccPreCheck(t) + testAccPreCheckSkipRds(t) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) + if err != nil { + return + } + + requiredVersion, _ := version.NewVersion("8.0.0") + currentVersion, err := serverVersion(db) + if err != nil { + return + } + + if currentVersion.LessThan(requiredVersion) { + t.Skip("Roles require MySQL 8+") + } + }, + ProviderFactories: testAccProviderFactories, + CheckDestroy: testAccRoleCheckDestroy(roleName), + Steps: []resource.TestStep{ + { + Config: testAccRoleConfigBasic(roleName), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateId: roleName, + }, + }, + }) +} + +func testAccRoleConfigWithGrants(roleName string) string { + return fmt.Sprintf(` +resource "mysql_role" "test" { + name = "%s" +} + +resource "mysql_grant" "test" { + user = "%s" + host = "%%" + database = "*" + table = "*" + privileges = ["SELECT"] +} +`, roleName, roleName) +} diff --git a/mysql/resource_user.go b/mysql/resource_user.go index e9d99c37..e0238011 100644 --- a/mysql/resource_user.go +++ b/mysql/resource_user.go @@ -16,8 +16,12 @@ import ( ) // formatUserIdentifier formats a user identifier with proper quoting for MySQL +// It uses single quotes with doubled-quote escaping (e.g., 'user'@'host') for better +// compatibility across MySQL/MariaDB versions and to support @ in usernames (GCP IAM). func formatUserIdentifier(user, host string) string { - return fmt.Sprintf("%s@%s", quoteIdentifier(user), quoteIdentifier(host)) + escapedUser := strings.ReplaceAll(user, "'", "''") + escapedHost := strings.ReplaceAll(host, "'", "''") + return fmt.Sprintf("'%s'@'%s'", escapedUser, escapedHost) } // quoteString escapes and quotes a string literal for MySQL @@ -597,21 +601,39 @@ func DeleteUser(ctx context.Context, d *schema.ResourceData, meta interface{}) d return diag.FromErr(err) } -func ImportUser(ctx context.Context, d *schema.ResourceData, meta interface{}) ([]*schema.ResourceData, error) { - userHost := strings.SplitN(d.Id(), "@", 2) +// parseUserHost parses a user@host string, handling usernames that contain @ +// (like GCP IAM email addresses). It uses the last @ as the separator. +func parseUserHost(id string) (user, host string, err error) { + // Find the last @ to handle usernames containing @ (like GCP IAM emails) + lastAt := strings.LastIndex(id, "@") + if lastAt == -1 { + // No @ found - use default host + return id, "localhost", nil + } + if lastAt == 0 { + return "", "", fmt.Errorf("wrong ID format %s (user cannot start with @)", id) + } + + user = id[:lastAt] + host = id[lastAt+1:] + if host == "" { + host = "localhost" + } + return user, host, nil +} - if len(userHost) != 2 { - return nil, fmt.Errorf("wrong ID format %s (expected USER@HOST)", d.Id()) +func ImportUser(ctx context.Context, d *schema.ResourceData, meta interface{}) ([]*schema.ResourceData, error) { + user, host, err := parseUserHost(d.Id()) + if err != nil { + return nil, err } - user := userHost[0] - host := userHost[1] d.Set("user", user) d.Set("host", host) - err := ReadUser(ctx, d, meta) + diags := ReadUser(ctx, d, meta) var ferror error - if err.HasError() { - ferror = fmt.Errorf("failed reading user: %v", err) + if diags.HasError() { + ferror = fmt.Errorf("failed reading user: %v", diags) } return []*schema.ResourceData{d}, ferror