Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion pkg/connector/server_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ func (d *userPrincipalSyncer) CreateAccount(
return nil, nil, nil, fmt.Errorf("failed to create login: %w", err)
}

uid, err := d.client.GetUserPrincipalByName(ctx, username)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get user: %w", err)
}

// Create a resource for the newly created login
profile := map[string]interface{}{
"username": username,
Expand All @@ -176,7 +181,7 @@ func (d *userPrincipalSyncer) CreateAccount(
resource, err := resource.NewUserResource(
formattedUsername,
d.ResourceType(ctx),
formattedUsername, // Use the formatted username as the ID
uid.ID,
userOpts,
)
if err != nil {
Expand Down
15 changes: 10 additions & 5 deletions pkg/mssqldb/roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,17 +324,22 @@ WHERE type = 'R' AND principal_id = @p1
return &roleModel, err
}

func (c *Client) AddUserToServerRole(ctx context.Context, role string, user string) error {
func (c *Client) AddUserToServerRole(ctx context.Context, role string, userID string) error {
l := ctxzap.Extract(ctx)
l.Debug("adding user to database role", zap.String("role", role), zap.String("user", user))
l.Debug("adding user to database role", zap.String("role", role), zap.String("userID", userID))

if strings.ContainsAny(role, "[]\"';") || strings.ContainsAny(user, "[]\"';") {
if strings.ContainsAny(role, "[]\"';") || strings.ContainsAny(userID, "[]\"';") {
return fmt.Errorf("invalid characters in role or user")
}

query := fmt.Sprintf(`ALTER SERVER ROLE [%s] ADD MEMBER [%s];`, role, user)
user, err := c.GetUserPrincipal(ctx, userID)
if err != nil {
return fmt.Errorf("cannot get user: %w", err)
}

_, err := c.db.ExecContext(ctx, query)
query := fmt.Sprintf(`ALTER SERVER ROLE [%s] ADD MEMBER [%s];`, role, user.Name)

_, err = c.db.ExecContext(ctx, query)
if err != nil {
return err
}
Expand Down
37 changes: 37 additions & 0 deletions pkg/mssqldb/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,43 @@ WHERE
return &userModel, nil
}

func (c *Client) GetUserPrincipalByName(ctx context.Context, name string) (*UserModel, error) {
l := ctxzap.Extract(ctx)
l.Debug("getting user")

query := `
SELECT
principal_id,
sid,
name,
type_desc,
is_disabled
FROM
sys.server_principals
WHERE
(
type = 'S'
OR type = 'U'
OR type = 'C'
OR type = 'E'
OR type = 'K'
) AND name = @p1
`

rows := c.db.QueryRowxContext(ctx, query, name)

var userModel UserModel
err := rows.StructScan(&userModel)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("user name not found: %s", name)
}
return nil, err
}

return &userModel, nil
}

// GetUserFromDb find db user from Server principal.
func (c *Client) GetUserFromDb(ctx context.Context, db, principalId string) (*UserDBModel, error) {
l := ctxzap.Extract(ctx)
Expand Down
Loading