Skip to content

Commit 29ee540

Browse files
Merge pull request #30 from ConductorOne/fix/BB-972
[BB-972] fix. Function for formating username implemented
2 parents a052cdb + 6d96841 commit 29ee540

File tree

2 files changed

+50
-50
lines changed

2 files changed

+50
-50
lines changed

pkg/connector/server_user.go

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"math/big"
88
"net/mail"
9+
"strings"
910

1011
v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2"
1112
"github.com/conductorone/baton-sdk/pkg/annotations"
@@ -98,6 +99,7 @@ func (d *userPrincipalSyncer) CreateAccount(
9899
accountInfo *v2.AccountInfo,
99100
credentialOptions *v2.CredentialOptions,
100101
) (connectorbuilder.CreateAccountResponse, []*v2.PlaintextData, annotations.Annotations, error) {
102+
var domain, formattedUsername, password string
101103
l := ctxzap.Extract(ctx)
102104

103105
// Extract required login_type field from profile
@@ -115,43 +117,24 @@ func (d *userPrincipalSyncer) CreateAccount(
115117
}
116118
username := usernameVal.GetStringValue()
117119

118-
// Extract optional domain field (for Windows auth) or password (for SQL auth)
119-
var domain, password string
120-
var formattedUsername string
121-
122-
switch loginType {
123-
case mssqldb.LoginTypeWindows:
124-
// For Windows auth, extract domain
125-
domainVal := accountInfo.Profile.GetFields()["domain"]
126-
if domainVal != nil && domainVal.GetStringValue() != "" {
127-
domain = domainVal.GetStringValue()
128-
}
120+
domainVal := accountInfo.Profile.GetFields()["domain"]
121+
if domainVal != nil && domainVal.GetStringValue() != "" {
122+
domain = domainVal.GetStringValue()
123+
}
129124

130-
if domain != "" {
131-
formattedUsername = fmt.Sprintf("%s\\%s", domain, username)
132-
} else {
133-
formattedUsername = username
134-
}
135-
case mssqldb.LoginTypeSQL:
136-
// For SQL auth, generate a strong random password
137-
password = generateStrongPassword()
138-
l.Debug("generated random password for SQL Server authentication")
139-
formattedUsername = username
140-
case mssqldb.LoginTypeAzureAD, mssqldb.LoginTypeEntraID:
141-
// For Azure AD or Entra ID, just use the username as is
142-
formattedUsername = username
143-
default:
144-
return nil, nil, nil, fmt.Errorf("unsupported login type: %s", loginType)
125+
formattedUsername, password, err := formatUserLogin(ctx, loginType, username, domain)
126+
if err != nil {
127+
return nil, nil, nil, err
145128
}
146129

147130
// Create the login
148-
err := d.client.CreateLogin(ctx, loginType, domain, username, password)
131+
err = d.client.CreateLogin(ctx, loginType, formattedUsername, password)
149132
if err != nil {
150133
l.Error("Failed to create login", zap.Error(err), zap.String("loginType", string(loginType)))
151134
return nil, nil, nil, fmt.Errorf("failed to create login: %w", err)
152135
}
153136

154-
uid, err := d.client.GetUserPrincipalByName(ctx, username)
137+
uid, err := d.client.GetUserPrincipalByName(ctx, formattedUsername)
155138
if err != nil {
156139
return nil, nil, nil, fmt.Errorf("failed to get user: %w", err)
157140
}
@@ -211,6 +194,42 @@ func (d *userPrincipalSyncer) CreateAccount(
211194
return successResult, plaintextData, nil, nil
212195
}
213196

197+
func formatUserLogin(ctx context.Context, loginType mssqldb.LoginType, username string, domain string) (string, string, error) {
198+
var formattedUsername, password string
199+
l := ctxzap.Extract(ctx)
200+
201+
// Check for invalid characters to prevent SQL injection
202+
if (domain != "" && strings.ContainsAny(domain, "[]\"';")) || strings.ContainsAny(username, "[]\"';") {
203+
return "", "", fmt.Errorf("invalid characters in domain or username")
204+
}
205+
206+
switch loginType {
207+
case mssqldb.LoginTypeWindows:
208+
if domain != "" {
209+
formattedUsername = fmt.Sprintf("%s\\%s", domain, username)
210+
l.Debug("windows login will be created with domain", zap.String("login", formattedUsername))
211+
} else {
212+
formattedUsername = username
213+
l.Debug("windows login will be created without domain", zap.String("login", formattedUsername))
214+
}
215+
216+
case mssqldb.LoginTypeSQL:
217+
// For SQL auth, generate a strong random password
218+
password = generateStrongPassword()
219+
l.Debug("generated random password for SQL Server authentication")
220+
formattedUsername = username
221+
222+
case mssqldb.LoginTypeAzureAD, mssqldb.LoginTypeEntraID:
223+
// For Azure AD or Entra ID, just use the username as is
224+
formattedUsername = username
225+
226+
default:
227+
return "", "", fmt.Errorf("unsupported login type: %s", loginType)
228+
}
229+
230+
return formattedUsername, password, nil
231+
}
232+
214233
// CreateAccountCapabilityDetails returns the capability details for account creation.
215234
func (d *userPrincipalSyncer) CreateAccountCapabilityDetails(
216235
ctx context.Context,

pkg/mssqldb/users.go

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -383,25 +383,14 @@ const (
383383
// For Entra ID authentication (loginType=ENTRA_ID):
384384
// - It creates from EXTERNAL PROVIDER
385385
// - Username should be the full Entra ID username/email
386-
func (c *Client) CreateLogin(ctx context.Context, loginType LoginType, domain, username, password string) error {
386+
func (c *Client) CreateLogin(ctx context.Context, loginType LoginType, username, password string) error {
387387
l := ctxzap.Extract(ctx)
388388

389-
// Check for invalid characters to prevent SQL injection
390-
if (domain != "" && strings.ContainsAny(domain, "[]\"';")) || strings.ContainsAny(username, "[]\"';") {
391-
return fmt.Errorf("invalid characters in domain or username")
392-
}
393-
394389
var query string
395390
switch loginType {
396391
case LoginTypeWindows:
397-
var loginName string
398-
if domain != "" {
399-
loginName = fmt.Sprintf("[%s\\%s]", domain, username)
400-
l.Debug("creating windows login with domain", zap.String("login", loginName))
401-
} else {
402-
loginName = fmt.Sprintf("[%s]", username)
403-
l.Debug("creating windows login without domain", zap.String("login", loginName))
404-
}
392+
loginName := fmt.Sprintf("[%s]", username)
393+
l.Debug("creating windows login", zap.String("login", loginName))
405394
query = fmt.Sprintf("CREATE LOGIN %s FROM WINDOWS;", loginName)
406395
case LoginTypeSQL:
407396
if password == "" {
@@ -429,11 +418,3 @@ func (c *Client) CreateLogin(ctx context.Context, loginType LoginType, domain, u
429418

430419
return nil
431420
}
432-
433-
// CreateWindowsLogin creates a SQL Server login from Windows AD for the specified domain and username.
434-
// If domain is provided, it will create the login in the format [DOMAIN\Username],
435-
// otherwise it will use just [Username].
436-
// This is a convenience method that calls CreateLogin with LoginTypeWindows.
437-
func (c *Client) CreateWindowsLogin(ctx context.Context, domain, username string) error {
438-
return c.CreateLogin(ctx, LoginTypeWindows, domain, username, "")
439-
}

0 commit comments

Comments
 (0)