Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion pkg/connector/server_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ func (d *userPrincipalSyncer) CreateAccount(
return nil, nil, nil, fmt.Errorf("failed to create login: %w", err)
}

uid, err := d.client.GetUserPrincipalByName(ctx, username)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get user: %w", err)
}

// Create a resource for the newly created login
profile := map[string]interface{}{
"username": username,
Expand All @@ -176,7 +181,7 @@ func (d *userPrincipalSyncer) CreateAccount(
resource, err := resource.NewUserResource(
formattedUsername,
d.ResourceType(ctx),
formattedUsername, // Use the formatted username as the ID
uid.ID,
userOpts,
)
if err != nil {
Expand Down
21 changes: 16 additions & 5 deletions pkg/mssqldb/roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,9 @@ WHERE type = 'R' AND principal_id = @p1

var roleModel RoleModel
row := c.db.QueryRowxContext(ctx, query, id)
if err := row.Err(); err != nil {
return nil, err
}

err := row.StructScan(&roleModel)
if err != nil {
Expand Down Expand Up @@ -315,6 +318,9 @@ WHERE type = 'R' AND principal_id = @p1

var roleModel RoleModel
row := c.db.QueryRowxContext(ctx, query, id)
if err := row.Err(); err != nil {
return nil, err
}

err := row.StructScan(&roleModel)
if err != nil {
Expand All @@ -324,17 +330,22 @@ WHERE type = 'R' AND principal_id = @p1
return &roleModel, err
}

func (c *Client) AddUserToServerRole(ctx context.Context, role string, user string) error {
func (c *Client) AddUserToServerRole(ctx context.Context, role string, userID string) error {
l := ctxzap.Extract(ctx)
l.Debug("adding user to database role", zap.String("role", role), zap.String("user", user))
l.Debug("adding user to database role", zap.String("role", role), zap.String("userID", userID))

if strings.ContainsAny(role, "[]\"';") || strings.ContainsAny(user, "[]\"';") {
if strings.ContainsAny(role, "[]\"';") || strings.ContainsAny(userID, "[]\"';") {
return fmt.Errorf("invalid characters in role or user")
}

query := fmt.Sprintf(`ALTER SERVER ROLE [%s] ADD MEMBER [%s];`, role, user)
user, err := c.GetUserPrincipal(ctx, userID)
if err != nil {
return fmt.Errorf("cannot get user: %w", err)
}

_, err := c.db.ExecContext(ctx, query)
query := fmt.Sprintf(`ALTER SERVER ROLE [%s] ADD MEMBER [%s];`, role, user.Name)

_, err = c.db.ExecContext(ctx, query)
if err != nil {
return err
}
Expand Down
46 changes: 46 additions & 0 deletions pkg/mssqldb/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,9 @@ WHERE
`

rows := c.db.QueryRowxContext(ctx, query, userId)
if err := rows.Err(); err != nil {
return nil, err
}

var userModel UserModel
err := rows.StructScan(&userModel)
Expand All @@ -236,6 +239,46 @@ WHERE
return &userModel, nil
}

func (c *Client) GetUserPrincipalByName(ctx context.Context, name string) (*UserModel, error) {
l := ctxzap.Extract(ctx)
l.Debug("getting user")

query := `
SELECT
principal_id,
sid,
name,
type_desc,
is_disabled
FROM
sys.server_principals
WHERE
(
type = 'S'
OR type = 'U'
OR type = 'C'
OR type = 'E'
OR type = 'K'
) AND name = @p1
`

rows := c.db.QueryRowxContext(ctx, query, name)
if err := rows.Err(); err != nil {
return nil, err
}

var userModel UserModel
err := rows.StructScan(&userModel)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("user name not found: %s", name)
}
return nil, err
}

return &userModel, nil
}

// GetUserFromDb find db user from Server principal.
func (c *Client) GetUserFromDb(ctx context.Context, db, principalId string) (*UserDBModel, error) {
l := ctxzap.Extract(ctx)
Expand Down Expand Up @@ -267,6 +310,9 @@ AND sp.principal_id = @p1
query = fmt.Sprintf(query, db)

row := c.db.QueryRowxContext(ctx, query, principalId)
if err := row.Err(); err != nil {
return nil, err
}

var userModel UserDBModel
err := row.StructScan(&userModel)
Expand Down
Loading