diff --git a/pkg/connector/server_user.go b/pkg/connector/server_user.go index 637e64f..1f549ea 100644 --- a/pkg/connector/server_user.go +++ b/pkg/connector/server_user.go @@ -151,6 +151,11 @@ func (d *userPrincipalSyncer) CreateAccount( return nil, nil, nil, fmt.Errorf("failed to create login: %w", err) } + uid, err := d.client.GetUserPrincipalByName(ctx, username) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to get user: %w", err) + } + // Create a resource for the newly created login profile := map[string]interface{}{ "username": username, @@ -176,7 +181,7 @@ func (d *userPrincipalSyncer) CreateAccount( resource, err := resource.NewUserResource( formattedUsername, d.ResourceType(ctx), - formattedUsername, // Use the formatted username as the ID + uid.ID, userOpts, ) if err != nil { diff --git a/pkg/mssqldb/roles.go b/pkg/mssqldb/roles.go index 0cc5bb7..90e39c2 100644 --- a/pkg/mssqldb/roles.go +++ b/pkg/mssqldb/roles.go @@ -280,6 +280,9 @@ WHERE type = 'R' AND principal_id = @p1 var roleModel RoleModel row := c.db.QueryRowxContext(ctx, query, id) + if err := row.Err(); err != nil { + return nil, err + } err := row.StructScan(&roleModel) if err != nil { @@ -315,6 +318,9 @@ WHERE type = 'R' AND principal_id = @p1 var roleModel RoleModel row := c.db.QueryRowxContext(ctx, query, id) + if err := row.Err(); err != nil { + return nil, err + } err := row.StructScan(&roleModel) if err != nil { @@ -324,17 +330,22 @@ WHERE type = 'R' AND principal_id = @p1 return &roleModel, err } -func (c *Client) AddUserToServerRole(ctx context.Context, role string, user string) error { +func (c *Client) AddUserToServerRole(ctx context.Context, role string, userID string) error { l := ctxzap.Extract(ctx) - l.Debug("adding user to database role", zap.String("role", role), zap.String("user", user)) + l.Debug("adding user to database role", zap.String("role", role), zap.String("userID", userID)) - if strings.ContainsAny(role, "[]\"';") || strings.ContainsAny(user, "[]\"';") { + if strings.ContainsAny(role, "[]\"';") || strings.ContainsAny(userID, "[]\"';") { return fmt.Errorf("invalid characters in role or user") } - query := fmt.Sprintf(`ALTER SERVER ROLE [%s] ADD MEMBER [%s];`, role, user) + user, err := c.GetUserPrincipal(ctx, userID) + if err != nil { + return fmt.Errorf("cannot get user: %w", err) + } - _, err := c.db.ExecContext(ctx, query) + query := fmt.Sprintf(`ALTER SERVER ROLE [%s] ADD MEMBER [%s];`, role, user.Name) + + _, err = c.db.ExecContext(ctx, query) if err != nil { return err } diff --git a/pkg/mssqldb/users.go b/pkg/mssqldb/users.go index 3e7812c..1555827 100644 --- a/pkg/mssqldb/users.go +++ b/pkg/mssqldb/users.go @@ -223,6 +223,9 @@ WHERE ` rows := c.db.QueryRowxContext(ctx, query, userId) + if err := rows.Err(); err != nil { + return nil, err + } var userModel UserModel err := rows.StructScan(&userModel) @@ -236,6 +239,46 @@ WHERE return &userModel, nil } +func (c *Client) GetUserPrincipalByName(ctx context.Context, name string) (*UserModel, error) { + l := ctxzap.Extract(ctx) + l.Debug("getting user") + + query := ` +SELECT + principal_id, + sid, + name, + type_desc, + is_disabled +FROM + sys.server_principals +WHERE + ( + type = 'S' + OR type = 'U' + OR type = 'C' + OR type = 'E' + OR type = 'K' + ) AND name = @p1 +` + + rows := c.db.QueryRowxContext(ctx, query, name) + if err := rows.Err(); err != nil { + return nil, err + } + + var userModel UserModel + err := rows.StructScan(&userModel) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("user name not found: %s", name) + } + return nil, err + } + + return &userModel, nil +} + // GetUserFromDb find db user from Server principal. func (c *Client) GetUserFromDb(ctx context.Context, db, principalId string) (*UserDBModel, error) { l := ctxzap.Extract(ctx) @@ -267,6 +310,9 @@ AND sp.principal_id = @p1 query = fmt.Sprintf(query, db) row := c.db.QueryRowxContext(ctx, query, principalId) + if err := row.Err(); err != nil { + return nil, err + } var userModel UserDBModel err := row.StructScan(&userModel)