Skip to content

Commit 06d3fb2

Browse files
handsomejack-42bgavrilMS
authored andcommitted
feat(oauth): add support for dSTS authority type
1 parent 133b78f commit 06d3fb2

File tree

5 files changed

+99
-24
lines changed

5 files changed

+99
-24
lines changed

apps/confidential/confidential_test.go

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,16 @@ import (
2121
"testing"
2222
"time"
2323

24+
"github.com/golang-jwt/jwt/v5"
25+
"github.com/kylelemons/godebug/pretty"
26+
2427
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
2528
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
2629
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
2730
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock"
2831
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
2932
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
3033
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
31-
"github.com/golang-jwt/jwt/v5"
32-
"github.com/kylelemons/godebug/pretty"
3334
)
3435

3536
// errorClient is an HTTP client for tests that should fail when confidential.Client sends a request
@@ -1405,3 +1406,59 @@ func TestWithAuthenticationScheme(t *testing.T) {
14051406
t.Fatalf(`unexpected access token "%s"`, result.AccessToken)
14061407
}
14071408
}
1409+
1410+
func TestAcquireTokenByCredentialFromDSTS(t *testing.T) {
1411+
tests := map[string]struct {
1412+
cred string
1413+
}{
1414+
"secret": {cred: "fake_secret"},
1415+
"signed assertion": {cred: "fake_assertion"},
1416+
}
1417+
1418+
for name, test := range tests {
1419+
t.Run(name, func(t *testing.T) {
1420+
cred, err := NewCredFromSecret(test.cred)
1421+
if err != nil {
1422+
t.Fatal(err)
1423+
}
1424+
client, err := fakeClient(accesstokens.TokenResponse{
1425+
AccessToken: token,
1426+
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
1427+
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
1428+
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
1429+
TokenType: "Bearer",
1430+
}, cred, "https://fake_authority/dstsv2/"+authority.DSTSTenant)
1431+
if err != nil {
1432+
t.Fatal(err)
1433+
}
1434+
1435+
// expect first attempt to fail
1436+
_, err = client.AcquireTokenSilent(context.Background(), tokenScope)
1437+
if err == nil {
1438+
t.Errorf("unexpected nil error from AcquireTokenSilent: %s", err)
1439+
}
1440+
1441+
tk, err := client.AcquireTokenByCredential(context.Background(), tokenScope)
1442+
if err != nil {
1443+
t.Errorf("got err == %s, want err == nil", err)
1444+
}
1445+
if tk.AccessToken != token {
1446+
t.Errorf("unexpected access token %s", tk.AccessToken)
1447+
}
1448+
1449+
tk, err = client.AcquireTokenSilent(context.Background(), tokenScope)
1450+
if err != nil {
1451+
t.Errorf("got err == %s, want err == nil", err)
1452+
}
1453+
if tk.AccessToken != token {
1454+
t.Errorf("unexpected access token %s", tk.AccessToken)
1455+
}
1456+
1457+
// fail for another tenant
1458+
tk, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("other"))
1459+
if err == nil {
1460+
t.Errorf("unexpected nil error from AcquireTokenSilent: %s", err)
1461+
}
1462+
})
1463+
}
1464+
}

apps/internal/oauth/oauth.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"io"
1111
"time"
1212

13+
"github.com/google/uuid"
14+
1315
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
1416
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
1517
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
@@ -18,7 +20,6 @@ import (
1820
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
1921
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
2022
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs"
21-
"github.com/google/uuid"
2223
)
2324

2425
// ResolveEndpointer contains the methods for resolving authority endpoints.

apps/internal/oauth/ops/authority/authority.go

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,12 @@ const (
136136
const (
137137
AAD = "MSSTS"
138138
ADFS = "ADFS"
139+
DSTS = "DSTS"
139140
)
140141

142+
// DSTSTenant is referenced throughout multiple files, let us use a const in case we ever need to change it.
143+
const DSTSTenant = "7a433bfc-2514-4697-b467-e0933190487f"
144+
141145
// AuthenticationScheme is an extensibility mechanism designed to be used only by Azure Arc for proof of possession access tokens.
142146
type AuthenticationScheme interface {
143147
// Extra parameters that are added to the request to the /token endpoint.
@@ -251,6 +255,8 @@ func (p AuthParams) WithTenant(ID string) (AuthParams, error) {
251255
authority = "https://" + path.Join(p.AuthorityInfo.Host, ID)
252256
case ADFS:
253257
return p, errors.New("ADFS authority doesn't support tenants")
258+
case DSTS:
259+
return p, errors.New("dSTS authority doesn't support tenants")
254260
}
255261

256262
info, err := NewInfoFromAuthorityURI(authority, p.AuthorityInfo.ValidateAuthority, p.AuthorityInfo.InstanceDiscoveryDisabled)
@@ -350,35 +356,43 @@ type Info struct {
350356
InstanceDiscoveryDisabled bool
351357
}
352358

353-
func firstPathSegment(u *url.URL) (string, error) {
354-
pathParts := strings.Split(u.EscapedPath(), "/")
355-
if len(pathParts) >= 2 {
356-
return pathParts[1], nil
357-
}
358-
359-
return "", errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/<your tenant>"`)
360-
}
361-
362359
// NewInfoFromAuthorityURI creates an AuthorityInfo instance from the authority URL provided.
363360
func NewInfoFromAuthorityURI(authority string, validateAuthority bool, instanceDiscoveryDisabled bool) (Info, error) {
364361
u, err := url.Parse(strings.ToLower(authority))
365-
if err != nil || u.Scheme != "https" {
366-
return Info{}, errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/<your tenant>"`)
362+
if err != nil {
363+
return Info{}, fmt.Errorf("couldn't parse authority url: %w", err)
364+
}
365+
if u.Scheme != "https" {
366+
return Info{}, errors.New("authority url scheme must be https")
367367
}
368368

369-
tenant, err := firstPathSegment(u)
370-
if err != nil {
371-
return Info{}, err
369+
pathParts := strings.Split(u.EscapedPath(), "/")
370+
if len(pathParts) < 2 {
371+
return Info{}, errors.New(`authority must be an URL such as "https://login.microsoftonline.com/<your tenant>"`)
372372
}
373-
authorityType := AAD
374-
if tenant == "adfs" {
373+
374+
var authorityType, tenant string
375+
switch pathParts[1] {
376+
case "adfs":
375377
authorityType = ADFS
378+
case "dstsv2":
379+
if len(pathParts) != 3 {
380+
return Info{}, fmt.Errorf("dSTS authority must be an https URL such as https://<authority>/dstsv2/%s", DSTSTenant)
381+
}
382+
if pathParts[2] != DSTSTenant {
383+
return Info{}, fmt.Errorf("dSTS authority only accepts a single tenant %q", DSTSTenant)
384+
}
385+
authorityType = DSTS
386+
tenant = DSTSTenant
387+
default:
388+
authorityType = AAD
389+
tenant = pathParts[1]
376390
}
377391

378392
// u.Host includes the port, if any, which is required for private cloud deployments
379393
return Info{
380394
Host: u.Host,
381-
CanonicalAuthorityURI: fmt.Sprintf("https://%v/%v/", u.Host, tenant),
395+
CanonicalAuthorityURI: authority,
382396
AuthorityType: authorityType,
383397
ValidateAuthority: validateAuthority,
384398
Tenant: tenant,

apps/internal/oauth/ops/authority/authority_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ func TestAuthParamsWithTenant(t *testing.T) {
341341
"tenant can't be consumers for AAD": {authority: host + uuid1, tenant: "consumers", expectError: true},
342342
"tenant can't be organizations for AAD": {authority: host + uuid1, tenant: "organizations", expectError: true},
343343
"can't override tenant for ADFS ever": {authority: host + "adfs", tenant: uuid1, expectError: true},
344+
"can't override tenant for dSTS ever": {authority: host + "dstsv2/" + DSTSTenant, tenant: uuid1, expectError: true},
344345
"can't override AAD tenant consumers": {authority: host + "consumers", tenant: uuid1, expectError: true},
345346
}
346347

apps/internal/oauth/resolvers.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func (m *authorityEndpoint) ResolveEndpoints(ctx context.Context, authorityInfo
4848
return endpoints, nil
4949
}
5050

51-
endpoint, err := m.openIDConfigurationEndpoint(ctx, authorityInfo, userPrincipalName)
51+
endpoint, err := m.openIDConfigurationEndpoint(ctx, authorityInfo)
5252
if err != nil {
5353
return authority.Endpoints{}, err
5454
}
@@ -116,9 +116,12 @@ func (m *authorityEndpoint) addCachedEndpoints(authorityInfo authority.Info, use
116116
m.cache[authorityInfo.CanonicalAuthorityURI] = updatedCacheEntry
117117
}
118118

119-
func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (string, error) {
120-
if authorityInfo.Tenant == "adfs" {
119+
func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info) (string, error) {
120+
if authorityInfo.AuthorityType == authority.ADFS {
121121
return fmt.Sprintf("https://%s/adfs/.well-known/openid-configuration", authorityInfo.Host), nil
122+
} else if authorityInfo.AuthorityType == authority.DSTS {
123+
return fmt.Sprintf("https://%s/dstsv2/%s/v2.0/.well-known/openid-configuration", authorityInfo.Host, authority.DSTSTenant), nil
124+
122125
} else if authorityInfo.ValidateAuthority && !authority.TrustedHost(authorityInfo.Host) {
123126
resp, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo)
124127
if err != nil {
@@ -131,7 +134,6 @@ func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, aut
131134
return "", err
132135
}
133136
return resp.TenantDiscoveryEndpoint, nil
134-
135137
}
136138

137139
return authorityInfo.CanonicalAuthorityURI + "v2.0/.well-known/openid-configuration", nil

0 commit comments

Comments
 (0)