Skip to content

Commit 68bac57

Browse files
btiplingclaude
andcommitted
[BB-610] Add support for multiple authentication types
- Add support for different login types: Windows, SQL Server, Azure AD, and Entra ID - Use FROM EXTERNAL PROVIDER for Azure AD and Entra ID authentication - Improve error handling and validation for different authentication types - Refactor code to use switch statement for better organization 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent 033538e commit 68bac57

File tree

3 files changed

+140
-37
lines changed

3 files changed

+140
-37
lines changed

pkg/connector/connector.go

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,45 @@ func (o *Mssqldb) Metadata(ctx context.Context) (*v2.ConnectorMetadata, error) {
3737
Description: "Baton connector for Microsoft SQL Server connector",
3838
AccountCreationSchema: &v2.ConnectorAccountCreationSchema{
3939
FieldMap: map[string]*v2.ConnectorAccountCreationSchema_Field{
40+
"login_type": {
41+
DisplayName: "Login Type",
42+
Required: true,
43+
Description: "The type of SQL Server authentication to use (WINDOWS, SQL, AZURE_AD, or ENTRA_ID).",
44+
Field: &v2.ConnectorAccountCreationSchema_Field_StringField{
45+
StringField: &v2.ConnectorAccountCreationSchema_StringField{},
46+
},
47+
Placeholder: "WINDOWS",
48+
Order: 1,
49+
},
4050
"domain": {
4151
DisplayName: "Active Directory Domain",
4252
Required: false,
43-
Description: "The Active Directory domain for the user (optional). If provided, the login will be created as [DOMAIN\\Username].",
53+
Description: "The Active Directory domain for the user. Only used for Windows Authentication.",
4454
Field: &v2.ConnectorAccountCreationSchema_Field_StringField{
4555
StringField: &v2.ConnectorAccountCreationSchema_StringField{},
4656
},
4757
Placeholder: "DOMAIN",
48-
Order: 1,
58+
Order: 2,
4959
},
5060
"username": {
5161
DisplayName: "Username",
5262
Required: true,
53-
Description: "The Active Directory username for which to create a SQL Server login.",
63+
Description: "The username for which to create a SQL Server login.",
5464
Field: &v2.ConnectorAccountCreationSchema_Field_StringField{
5565
StringField: &v2.ConnectorAccountCreationSchema_StringField{},
5666
},
5767
Placeholder: "username",
58-
Order: 2,
68+
Order: 3,
69+
},
70+
"password": {
71+
DisplayName: "Password",
72+
Required: false,
73+
Description: "The password for SQL Server authentication. Required when using SQL Server Authentication.",
74+
Field: &v2.ConnectorAccountCreationSchema_Field_StringField{
75+
StringField: &v2.ConnectorAccountCreationSchema_StringField{},
76+
},
77+
Placeholder: "password",
78+
Order: 4,
5979
},
6080
},
6181
},

pkg/connector/server_user.go

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func (d *userPrincipalSyncer) Grants(ctx context.Context, resource *v2.Resource,
8787
return nil, "", nil, nil
8888
}
8989

90-
// CreateAccount creates a SQL Server login for an Active Directory user without adding database users.
90+
// CreateAccount creates a SQL Server login based on the specified login type.
9191
// It implements the AccountManager interface.
9292
func (d *userPrincipalSyncer) CreateAccount(
9393
ctx context.Context,
@@ -96,42 +96,72 @@ func (d *userPrincipalSyncer) CreateAccount(
9696
) (connectorbuilder.CreateAccountResponse, []*v2.PlaintextData, annotations.Annotations, error) {
9797
l := ctxzap.Extract(ctx)
9898

99+
// Extract required login_type field from profile
100+
loginTypeVal := accountInfo.Profile.GetFields()["login_type"]
101+
if loginTypeVal == nil || loginTypeVal.GetStringValue() == "" {
102+
return nil, nil, nil, fmt.Errorf("missing required login_type field")
103+
}
104+
loginTypeStr := loginTypeVal.GetStringValue()
105+
loginType := mssqldb.LoginType(loginTypeStr)
106+
99107
// Extract required username field from profile
100108
usernameVal := accountInfo.Profile.GetFields()["username"]
101109
if usernameVal == nil || usernameVal.GetStringValue() == "" {
102110
return nil, nil, nil, fmt.Errorf("missing required username field")
103111
}
104112
username := usernameVal.GetStringValue()
105113

106-
// Extract optional domain field from profile
107-
var domain string
108-
domainVal := accountInfo.Profile.GetFields()["domain"]
109-
if domainVal != nil && domainVal.GetStringValue() != "" {
110-
domain = domainVal.GetStringValue()
111-
}
114+
// Extract optional domain field (for Windows auth) or password (for SQL auth)
115+
var domain, password string
116+
var formattedUsername string
112117

113-
// Create the Windows login
114-
err := d.client.CreateWindowsLogin(ctx, domain, username)
115-
if err != nil {
116-
l.Error("Failed to create Windows login", zap.Error(err))
117-
return nil, nil, nil, fmt.Errorf("failed to create Windows login: %w", err)
118-
}
118+
switch loginType {
119+
case mssqldb.LoginTypeWindows:
120+
// For Windows auth, extract domain
121+
domainVal := accountInfo.Profile.GetFields()["domain"]
122+
if domainVal != nil && domainVal.GetStringValue() != "" {
123+
domain = domainVal.GetStringValue()
124+
}
119125

120-
// Determine the formatted username for the login
121-
var formattedUsername string
122-
if domain != "" {
123-
formattedUsername = fmt.Sprintf("%s\\%s", domain, username)
124-
} else {
126+
if domain != "" {
127+
formattedUsername = fmt.Sprintf("%s\\%s", domain, username)
128+
} else {
129+
formattedUsername = username
130+
}
131+
case mssqldb.LoginTypeSQL:
132+
// For SQL auth, extract password
133+
passwordVal := accountInfo.Profile.GetFields()["password"]
134+
if passwordVal == nil || passwordVal.GetStringValue() == "" {
135+
return nil, nil, nil, fmt.Errorf("missing required password field for SQL Server authentication")
136+
}
137+
password = passwordVal.GetStringValue()
138+
formattedUsername = username
139+
case mssqldb.LoginTypeAzureAD, mssqldb.LoginTypeEntraID:
140+
// For Azure AD or Entra ID, just use the username as is
125141
formattedUsername = username
142+
default:
143+
return nil, nil, nil, fmt.Errorf("unsupported login type: %s", loginType)
144+
}
145+
146+
// Create the login
147+
err := d.client.CreateLogin(ctx, loginType, domain, username, password)
148+
if err != nil {
149+
l.Error("Failed to create login", zap.Error(err), zap.String("loginType", string(loginType)))
150+
return nil, nil, nil, fmt.Errorf("failed to create login: %w", err)
126151
}
127152

128153
// Create a resource for the newly created login
129154
profile := map[string]interface{}{
130155
"username": username,
131-
"domain": domain,
156+
"login_type": string(loginType),
132157
"formatted_login": formattedUsername,
133158
}
134159

160+
// Add domain if it exists (for Windows auth)
161+
if domain != "" {
162+
profile["domain"] = domain
163+
}
164+
135165
// Use email as name if it looks like an email address
136166
var userOpts []resource.UserTraitOption
137167
userOpts = append(userOpts, resource.WithUserProfile(profile))

pkg/mssqldb/users.go

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -307,34 +307,87 @@ CREATE USER [%s] FOR LOGIN [%s];
307307
return nil
308308
}
309309

310-
// CreateWindowsLogin creates a SQL Server login from Windows AD for the specified domain and username.
311-
// If domain is provided, it will create the login in the format [DOMAIN\Username],
312-
// otherwise it will use just [Username].
313-
func (c *Client) CreateWindowsLogin(ctx context.Context, domain, username string) error {
310+
// LoginType represents the SQL Server login type.
311+
type LoginType string
312+
313+
const (
314+
// LoginTypeWindows represents Windows authentication.
315+
LoginTypeWindows LoginType = "WINDOWS"
316+
// LoginTypeSQL represents SQL Server authentication.
317+
LoginTypeSQL LoginType = "SQL"
318+
// LoginTypeAzureAD represents Azure AD authentication.
319+
LoginTypeAzureAD LoginType = "AZURE_AD"
320+
// LoginTypeEntraID represents Azure Entra ID authentication.
321+
LoginTypeEntraID LoginType = "ENTRA_ID"
322+
)
323+
324+
// CreateLogin creates a SQL Server login with the specified authentication type.
325+
// For Windows authentication (loginType=WINDOWS):
326+
// - If domain is provided, it will create the login in the format [DOMAIN\Username]
327+
// - otherwise it will use just [Username]
328+
//
329+
// For SQL authentication (loginType=SQL):
330+
// - It requires a password
331+
// - Domain is ignored
332+
//
333+
// For Azure AD authentication (loginType=AZURE_AD):
334+
// - It creates from EXTERNAL PROVIDER
335+
// - Username should be the full Azure AD username/email
336+
//
337+
// For Entra ID authentication (loginType=ENTRA_ID):
338+
// - It creates from EXTERNAL PROVIDER
339+
// - Username should be the full Entra ID username/email
340+
func (c *Client) CreateLogin(ctx context.Context, loginType LoginType, domain, username, password string) error {
314341
l := ctxzap.Extract(ctx)
315342

316343
// Check for invalid characters to prevent SQL injection
317344
if (domain != "" && strings.ContainsAny(domain, "[]\"';")) || strings.ContainsAny(username, "[]\"';") {
318345
return fmt.Errorf("invalid characters in domain or username")
319346
}
320347

321-
var loginName string
322-
if domain != "" {
323-
loginName = fmt.Sprintf("[%s\\%s]", domain, username)
324-
l.Debug("creating windows login with domain", zap.String("login", loginName))
325-
} else {
326-
loginName = fmt.Sprintf("[%s]", username)
327-
l.Debug("creating windows login without domain", zap.String("login", loginName))
348+
var query string
349+
switch loginType {
350+
case LoginTypeWindows:
351+
var loginName string
352+
if domain != "" {
353+
loginName = fmt.Sprintf("[%s\\%s]", domain, username)
354+
l.Debug("creating windows login with domain", zap.String("login", loginName))
355+
} else {
356+
loginName = fmt.Sprintf("[%s]", username)
357+
l.Debug("creating windows login without domain", zap.String("login", loginName))
358+
}
359+
query = fmt.Sprintf("CREATE LOGIN %s FROM WINDOWS;", loginName)
360+
case LoginTypeSQL:
361+
if password == "" {
362+
return fmt.Errorf("password is required for SQL Server authentication")
363+
}
364+
// For SQL Server authentication, only username and password are used
365+
loginName := fmt.Sprintf("[%s]", username)
366+
l.Debug("creating SQL login", zap.String("login", loginName))
367+
query = fmt.Sprintf("CREATE LOGIN %s WITH PASSWORD = '%s';", loginName, password)
368+
case LoginTypeAzureAD, LoginTypeEntraID:
369+
// Azure AD and Entra ID use external provider
370+
loginName := fmt.Sprintf("[%s]", username)
371+
l.Debug("creating external provider login", zap.String("login", loginName), zap.String("type", string(loginType)))
372+
query = fmt.Sprintf("CREATE LOGIN %s FROM EXTERNAL PROVIDER;", loginName)
373+
default:
374+
return fmt.Errorf("unsupported login type: %s", loginType)
328375
}
329376

330-
query := fmt.Sprintf("CREATE LOGIN %s FROM WINDOWS;", loginName)
331-
332377
l.Debug("SQL QUERY", zap.String("q", query))
333378

334379
_, err := c.db.ExecContext(ctx, query)
335380
if err != nil {
336-
return fmt.Errorf("failed to create Windows login: %w", err)
381+
return fmt.Errorf("failed to create login: %w", err)
337382
}
338383

339384
return nil
340385
}
386+
387+
// CreateWindowsLogin creates a SQL Server login from Windows AD for the specified domain and username.
388+
// If domain is provided, it will create the login in the format [DOMAIN\Username],
389+
// otherwise it will use just [Username].
390+
// This is a convenience method that calls CreateLogin with LoginTypeWindows.
391+
func (c *Client) CreateWindowsLogin(ctx context.Context, domain, username string) error {
392+
return c.CreateLogin(ctx, LoginTypeWindows, domain, username, "")
393+
}

0 commit comments

Comments
 (0)