Skip to content

Commit 8024695

Browse files
committed
sql,sessioninit: return PROVISIONSRC role option from GetUserSessionInitInfo
This allows the PROVISIONSRC that has been configured to be used during authentication. Release note: None
1 parent 7f725b0 commit 8024695

File tree

11 files changed

+233
-7
lines changed

11 files changed

+233
-7
lines changed

pkg/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ ALL_TESTS = [
338338
"//pkg/security/certmgr:certmgr_test",
339339
"//pkg/security/clientcert:clientcert_test",
340340
"//pkg/security/password:password_test",
341+
"//pkg/security/provisioning:provisioning_test",
341342
"//pkg/security/sessionrevival:sessionrevival_test",
342343
"//pkg/security/username:username_disallowed_imports_test",
343344
"//pkg/security/username:username_test",
@@ -1718,6 +1719,7 @@ GO_TARGETS = [
17181719
"//pkg/security/password:password_test",
17191720
"//pkg/security/pprompt:pprompt",
17201721
"//pkg/security/provisioning:provisioning",
1722+
"//pkg/security/provisioning:provisioning_test",
17211723
"//pkg/security/securityassets:securityassets",
17221724
"//pkg/security/securitytest:securitytest",
17231725
"//pkg/security/sessionrevival:sessionrevival",

pkg/ccl/serverccl/role_authentication_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ func TestVerifyPassword(t *testing.T) {
362362
t.Run(tc.testName, func(t *testing.T) {
363363
execCfg := ts.ExecutorConfig().(sql.ExecutorConfig)
364364
username := username.MakeSQLUsernameFromPreNormalizedString(tc.username)
365-
exists, canLoginSQL, canLoginDBConsole, canUseReplicationMode, isSuperuser, _, _, pwRetrieveFn, err := sql.GetUserSessionInitInfo(
365+
exists, canLoginSQL, canLoginDBConsole, canUseReplicationMode, isSuperuser, _, _, _, pwRetrieveFn, err := sql.GetUserSessionInitInfo(
366366
context.Background(), &execCfg, username, "", /* databaseName */
367367
)
368368

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
load("@io_bazel_rules_go//go:def.bzl", "go_library")
1+
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
22

33
go_library(
44
name = "provisioning",
@@ -7,3 +7,13 @@ go_library(
77
visibility = ["//visibility:public"],
88
deps = ["@com_github_cockroachdb_errors//:errors"],
99
)
10+
11+
go_test(
12+
name = "provisioning_test",
13+
srcs = ["provisioning_source_test.go"],
14+
embed = [":provisioning"],
15+
deps = [
16+
"//pkg/util/leaktest",
17+
"@com_github_stretchr_testify//require",
18+
],
19+
)

pkg/security/provisioning/provisioning_source.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,18 @@ func parseAuthMethod(sourceStr string) (authMethod string, idp string, err error
6464
}
6565

6666
func parseIDP(idp string) (u *url.URL, err error) {
67+
if len(idp) == 0 {
68+
return nil, errors.Newf("PROVISIONSRC IDP cannot be empty")
69+
}
6770
if u, err = url.Parse(idp); err != nil {
6871
return nil, errors.Wrapf(err, "provided IDP %q in PROVISIONSRC is non parseable", idp)
6972
}
73+
if len(u.Port()) != 0 || u.Opaque != "" {
74+
return nil, errors.Newf("unknown PROVISIONSRC IDP url format in %q", idp)
75+
}
7076
return
7177
}
78+
79+
func (source *Source) Size() int {
80+
return len(source.authMethod) + len(source.idp.String())
81+
}
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
// Copyright 2025 The Cockroach Authors.
2+
//
3+
// Use of this software is governed by the CockroachDB Software License
4+
// included in the /LICENSE file.
5+
6+
package provisioning
7+
8+
import (
9+
"testing"
10+
11+
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
func TestParseProvisioningSource(t *testing.T) {
16+
defer leaktest.AfterTest(t)()
17+
tests := []struct {
18+
name string
19+
sourceStr string
20+
wantErr bool
21+
expectedMethod string
22+
expectedIDP string
23+
expectedErrMsg string
24+
}{
25+
{
26+
name: "valid ldap source",
27+
sourceStr: "ldap:ldap.bar.com",
28+
wantErr: false,
29+
expectedMethod: "ldap",
30+
expectedIDP: "ldap.bar.com",
31+
},
32+
{
33+
name: "valid ldap source with example.com",
34+
sourceStr: "ldap:ldap.example.com",
35+
wantErr: false,
36+
expectedMethod: "ldap",
37+
expectedIDP: "ldap.example.com",
38+
},
39+
{
40+
name: "valid ldap source with simple hostname",
41+
sourceStr: "ldap:foo.bar",
42+
wantErr: false,
43+
expectedMethod: "ldap",
44+
expectedIDP: "foo.bar",
45+
},
46+
{
47+
name: "missing auth method prefix",
48+
sourceStr: "ldap.example.com",
49+
wantErr: true,
50+
expectedErrMsg: `PROVISIONSRC "ldap.example.com" was not prefixed with any valid auth methods ["ldap"]`,
51+
},
52+
{
53+
name: "invalid characters in IDP",
54+
sourceStr: "ldap:[]!@#%#^$&*",
55+
wantErr: true,
56+
expectedErrMsg: `provided IDP "[]!@#%#^$&*" in PROVISIONSRC is non parseable`,
57+
},
58+
{
59+
name: "empty string",
60+
sourceStr: "",
61+
wantErr: true,
62+
expectedErrMsg: `PROVISIONSRC "" was not prefixed with any valid auth methods ["ldap"]`,
63+
},
64+
{
65+
name: "invalid auth method",
66+
sourceStr: "oauth:example.com",
67+
wantErr: true,
68+
expectedErrMsg: `PROVISIONSRC "oauth:example.com" was not prefixed with any valid auth methods ["ldap"]`,
69+
},
70+
{
71+
name: "only auth method without IDP",
72+
sourceStr: "ldap:",
73+
wantErr: true,
74+
expectedErrMsg: `PROVISIONSRC IDP cannot be empty`,
75+
},
76+
{
77+
name: "IDP url with port",
78+
sourceStr: "ldap:example.com:389",
79+
wantErr: true,
80+
expectedErrMsg: "unknown PROVISIONSRC IDP url format in \"example.com:389\"",
81+
},
82+
{
83+
name: "IDP url starts with double slash",
84+
sourceStr: "ldap://ldap.example.com",
85+
wantErr: false,
86+
expectedMethod: "ldap",
87+
expectedIDP: "//ldap.example.com",
88+
},
89+
{
90+
name: "space in IDP url",
91+
sourceStr: "ldap:ldap1 ldap2",
92+
wantErr: false,
93+
expectedMethod: "ldap",
94+
expectedIDP: "ldap1%20ldap2",
95+
},
96+
}
97+
98+
for _, tt := range tests {
99+
t.Run(tt.name, func(t *testing.T) {
100+
source, err := ParseProvisioningSource(tt.sourceStr)
101+
if tt.wantErr {
102+
require.Error(t, err)
103+
require.Nil(t, source)
104+
require.Contains(t, err.Error(), tt.expectedErrMsg)
105+
} else {
106+
require.NoError(t, err)
107+
require.NotNil(t, source)
108+
require.Equal(t, tt.expectedMethod, source.authMethod)
109+
require.Equal(t, tt.expectedIDP, source.idp.String())
110+
}
111+
})
112+
}
113+
}
114+
115+
func TestValidateSource(t *testing.T) {
116+
defer leaktest.AfterTest(t)()
117+
tests := []struct {
118+
name string
119+
sourceStr string
120+
wantErr bool
121+
expectedErrMsg string
122+
}{
123+
{
124+
name: "valid ldap source",
125+
sourceStr: "ldap:ldap.bar.com",
126+
wantErr: false,
127+
},
128+
{
129+
name: "valid ldap source with example.com",
130+
sourceStr: "ldap:ldap.example.com",
131+
wantErr: false,
132+
},
133+
{
134+
name: "valid ldap source with simple hostname",
135+
sourceStr: "ldap:foo.bar",
136+
wantErr: false,
137+
},
138+
{
139+
name: "missing auth method prefix",
140+
sourceStr: "ldap.example.com",
141+
wantErr: true,
142+
expectedErrMsg: `PROVISIONSRC "ldap.example.com" was not prefixed with any valid auth methods ["ldap"]`,
143+
},
144+
{
145+
name: "invalid characters in IDP",
146+
sourceStr: "ldap:[]!@#%#^$&*",
147+
wantErr: true,
148+
expectedErrMsg: `provided IDP "[]!@#%#^$&*" in PROVISIONSRC is non parseable`,
149+
},
150+
{
151+
name: "empty string",
152+
sourceStr: "",
153+
wantErr: true,
154+
expectedErrMsg: `PROVISIONSRC "" was not prefixed with any valid auth methods ["ldap"]`,
155+
},
156+
{
157+
name: "invalid auth method",
158+
sourceStr: "oauth:example.com",
159+
wantErr: true,
160+
expectedErrMsg: `PROVISIONSRC "oauth:example.com" was not prefixed with any valid auth methods ["ldap"]`,
161+
},
162+
{
163+
name: "only auth method without IDP",
164+
sourceStr: "ldap:",
165+
wantErr: true,
166+
expectedErrMsg: `PROVISIONSRC IDP cannot be empty`,
167+
},
168+
}
169+
170+
for _, tt := range tests {
171+
t.Run(tt.name, func(t *testing.T) {
172+
err := ValidateSource(tt.sourceStr)
173+
if tt.wantErr {
174+
require.Error(t, err)
175+
require.Contains(t, err.Error(), tt.expectedErrMsg)
176+
} else {
177+
require.NoError(t, err)
178+
}
179+
})
180+
}
181+
}

pkg/server/authserver/authentication.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ func (s *authenticationServer) UserLoginFromSSO(
335335
// without further normalization.
336336
username, _ := username.MakeSQLUsernameFromUserInput(reqUsername, username.PurposeValidation)
337337

338-
exists, _, canLoginDBConsole, _, _, _, _, _, err := sql.GetUserSessionInitInfo(
338+
exists, _, canLoginDBConsole, _, _, _, _, _, _, err := sql.GetUserSessionInitInfo(
339339
ctx,
340340
s.sqlServer.ExecutorConfig(),
341341
username,
@@ -493,7 +493,7 @@ func (s *authenticationServer) VerifyUserSessionDBConsole(
493493
pwRetrieveFn func(ctx context.Context) (expired bool, hashedPassword password.PasswordHash, err error),
494494
err error,
495495
) {
496-
exists, _, canLoginDBConsole, _, _, _, _, pwRetrieveFn, err := sql.GetUserSessionInitInfo(
496+
exists, _, canLoginDBConsole, _, _, _, _, _, pwRetrieveFn, err := sql.GetUserSessionInitInfo(
497497
ctx,
498498
s.sqlServer.ExecutorConfig(),
499499
userName,

pkg/sql/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@ go_library(
361361
"//pkg/security",
362362
"//pkg/security/distinguishedname",
363363
"//pkg/security/password",
364+
"//pkg/security/provisioning",
364365
"//pkg/security/sessionrevival",
365366
"//pkg/security/username",
366367
"//pkg/server/license",

pkg/sql/pgwire/auth.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ func (c *conn) handleAuthentication(
150150

151151
// Check that the requested user exists and retrieve the hashed
152152
// password in case password authentication is needed.
153-
exists, canLoginSQL, _, canUseReplicationMode, isSuperuser, defaultSettings, roleSubject, pwRetrievalFn, err :=
153+
exists, canLoginSQL, _, canUseReplicationMode, isSuperuser, defaultSettings, roleSubject, _, pwRetrievalFn, err :=
154154
sql.GetUserSessionInitInfo(
155155
ctx,
156156
execCfg,

pkg/sql/sessioninit/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ go_library(
1111
deps = [
1212
"//pkg/keys",
1313
"//pkg/security/password",
14+
"//pkg/security/provisioning",
1415
"//pkg/security/username",
1516
"//pkg/settings",
1617
"//pkg/settings/cluster",

pkg/sql/sessioninit/cache.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
"github.com/cockroachdb/cockroach/pkg/keys"
1414
"github.com/cockroachdb/cockroach/pkg/security/password"
15+
"github.com/cockroachdb/cockroach/pkg/security/provisioning"
1516
"github.com/cockroachdb/cockroach/pkg/security/username"
1617
"github.com/cockroachdb/cockroach/pkg/settings"
1718
"github.com/cockroachdb/cockroach/pkg/settings/cluster"
@@ -76,6 +77,10 @@ type AuthInfo struct {
7677
// Subject is the SUBJECT role option. It is used to match the subject
7778
// distinguished name in a client certificate.
7879
Subject *ldap.DN
80+
// ProvisioningSource is the PROVISIONSRC role option. It is used to
81+
// identify the source of the user in case a user auto provisioned from an
82+
// auth method integration.
83+
ProvisioningSource *provisioning.Source
7984
}
8085

8186
// SettingsCacheKey is the key used for the settingsCache.
@@ -261,9 +266,13 @@ func (a *Cache) maybeWriteAuthInfoBackToCache(
261266
}
262267
}
263268
}
269+
provisioningSourceSize := 0
270+
if aInfo.ProvisioningSource != nil {
271+
provisioningSourceSize += aInfo.ProvisioningSource.Size()
272+
}
264273

265274
sizeOfEntry := sizeOfUsername + len(user.Normalized()) +
266-
sizeOfAuthInfo + hpSize + sizeOfTimestamp + subjectSize
275+
sizeOfAuthInfo + hpSize + sizeOfTimestamp + subjectSize + provisioningSourceSize
267276
if err := a.boundAccount.Grow(ctx, int64(sizeOfEntry)); err != nil {
268277
// If there is no memory available to cache the entry, we can still
269278
// proceed with authentication so that users are not locked out of

0 commit comments

Comments
 (0)