Skip to content

Commit a45cd8c

Browse files
committed
integration tests and benchmarks
1 parent 7a05cec commit a45cd8c

File tree

3 files changed

+426
-0
lines changed

3 files changed

+426
-0
lines changed

extensions/integration_test.go

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License. See LICENSE in the project root for license information.
3+
4+
package extensions
5+
6+
import (
7+
"context"
8+
"fmt"
9+
"net/http"
10+
"path/filepath"
11+
"strings"
12+
"sync"
13+
"testing"
14+
15+
"github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/accessor/file"
16+
"github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache"
17+
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
18+
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
19+
"github.com/stretchr/testify/require"
20+
)
21+
22+
var ctx = context.Background()
23+
24+
func TestConfidentialClient(t *testing.T) {
25+
t.Parallel()
26+
p := filepath.Join(t.TempDir(), t.Name())
27+
a, err := file.New(p)
28+
require.NoError(t, err)
29+
c, err := cache.New(a, p+".timestamp")
30+
require.NoError(t, err)
31+
cred, err := confidential.NewCredFromSecret("*")
32+
require.NoError(t, err)
33+
client, err := confidential.New(
34+
"https://login.microsoftonline.com/tenant", "clientID", cred, confidential.WithCache(c), confidential.WithHTTPClient(&mockSTS{}),
35+
)
36+
require.NoError(t, err)
37+
38+
gr := 20
39+
wg := sync.WaitGroup{}
40+
for i := 0; i < gr; i++ {
41+
wg.Add(1)
42+
go func(n int) {
43+
defer wg.Done()
44+
if t.Failed() {
45+
return
46+
}
47+
s := fmt.Sprint(n)
48+
ar, err := client.AcquireTokenByCredential(ctx, []string{s})
49+
switch {
50+
case err != nil:
51+
t.Error(err)
52+
case ar.AccessToken != s:
53+
t.Errorf("possible test bug: expected %q from STS, got %q", s, ar.AccessToken)
54+
default:
55+
ar, err = client.AcquireTokenSilent(ctx, []string{s})
56+
if err != nil {
57+
t.Error(err)
58+
} else if ar.AccessToken != s {
59+
t.Errorf("possible cache corruption: expected %q, got %q", s, ar.AccessToken)
60+
}
61+
}
62+
}(i)
63+
}
64+
wg.Wait()
65+
if t.Failed() {
66+
return
67+
}
68+
69+
// cache should have an access token from each goroutine
70+
lost := gr
71+
for i := 0; i < gr; i++ {
72+
s := fmt.Sprint(i)
73+
ar, err := client.AcquireTokenSilent(ctx, []string{s})
74+
if err == nil {
75+
lost--
76+
if ar.AccessToken != s {
77+
t.Errorf("possible cache corruption: expected %q, got %q", s, ar.AccessToken)
78+
}
79+
}
80+
}
81+
require.Equal(t, 0, lost, "lost %d/%d tokens", lost, gr)
82+
}
83+
84+
func TestPublicClient(t *testing.T) {
85+
t.Parallel()
86+
p := filepath.Join(t.TempDir(), t.Name())
87+
a, err := file.New(p)
88+
require.NoError(t, err)
89+
c, err := cache.New(a, p+".timestamp")
90+
require.NoError(t, err)
91+
sts := mockSTS{}
92+
client, err := public.New("clientID", public.WithCache(c), public.WithHTTPClient(&sts))
93+
require.NoError(t, err)
94+
95+
gr := 20
96+
wg := sync.WaitGroup{}
97+
for i := 0; i < gr; i++ {
98+
wg.Add(1)
99+
go func(n int) {
100+
defer wg.Done()
101+
if t.Failed() {
102+
return
103+
}
104+
s := fmt.Sprint(n)
105+
ar, err := client.AcquireTokenByUsernamePassword(ctx, []string{s}, s, "password")
106+
switch {
107+
case err != nil:
108+
t.Error(err)
109+
case ar.AccessToken != s:
110+
t.Errorf("possible test bug: expected %q from STS, got %q", s, ar.AccessToken)
111+
default:
112+
ar, err = client.AcquireTokenSilent(ctx, []string{s}, public.WithSilentAccount(ar.Account))
113+
if err != nil {
114+
t.Error(err)
115+
} else if ar.AccessToken != s {
116+
t.Errorf("possible cache corruption: expected %q, got %q", s, ar.AccessToken)
117+
}
118+
}
119+
}(i)
120+
}
121+
wg.Wait()
122+
if t.Failed() {
123+
return
124+
}
125+
126+
accounts, err := client.Accounts(ctx)
127+
require.NoError(t, err)
128+
require.Equal(t, gr, len(accounts), "should have a cached account for each goroutine")
129+
130+
// Verify no access token cached above was lost due to a race. Silent auth should return a cached
131+
// access token given any scope above. A token request during this loop indicates the client
132+
// exchanged a refresh token to reacquire the access token it should have found in the cache.
133+
lostATs, reqs := 0, 0
134+
sts.tokenRequestCallback = func(*http.Request) { reqs++ }
135+
for _, a := range accounts {
136+
s, _, found := strings.Cut(a.HomeAccountID, ".")
137+
require.True(t, found, "unexpected home account ID %q", a.HomeAccountID)
138+
ar, err := client.AcquireTokenSilent(ctx, []string{s}, public.WithSilentAccount(a))
139+
if err != nil {
140+
// the cache has no access token for the expected scope and no refresh token for the account
141+
lostATs++
142+
} else if ar.AccessToken != s {
143+
t.Errorf("possible cache corruption: expected %q, got %q", s, ar.AccessToken)
144+
}
145+
}
146+
require.Equal(t, 0, lostATs+reqs, "lost %d/%d access tokens", reqs, gr)
147+
148+
// The cache has all the expected access tokens but may have lost refresh tokens, so we try silent
149+
// auth again for each account, passing a new scope to force the client to use a refresh token.
150+
lostRTs := 0
151+
for _, a := range accounts {
152+
s := "novelscope"
153+
ar, err := client.AcquireTokenSilent(ctx, []string{s}, public.WithSilentAccount(a))
154+
if err != nil {
155+
lostRTs++
156+
} else if ar.AccessToken != s {
157+
t.Errorf("possible cache corruption: expected %q, got %q", s, ar.AccessToken)
158+
}
159+
}
160+
require.Equal(t, 0, lostRTs, "lost %d/%d refresh tokens", lostRTs, gr)
161+
}

extensions/mock_test.go

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License. See LICENSE in the project root for license information.
3+
4+
package extensions
5+
6+
import (
7+
"bytes"
8+
"encoding/base64"
9+
"fmt"
10+
"io"
11+
"net/http"
12+
"strings"
13+
)
14+
15+
// mockSTS returns mock Azure AD responses so tests don't have to account for MSAL metadata requests
16+
type mockSTS struct {
17+
tokenRequestCallback func(*http.Request)
18+
}
19+
20+
func (m *mockSTS) Do(req *http.Request) (*http.Response, error) {
21+
res := http.Response{StatusCode: http.StatusOK}
22+
switch s := strings.Split(req.URL.Path, "/"); s[len(s)-1] {
23+
case "instance":
24+
res.Body = io.NopCloser(bytes.NewReader(instanceMetadata("tenant")))
25+
case "openid-configuration":
26+
res.Body = io.NopCloser(bytes.NewReader(tenantMetadata("tenant")))
27+
case "token":
28+
if m.tokenRequestCallback != nil {
29+
m.tokenRequestCallback(req)
30+
}
31+
if err := req.ParseForm(); err != nil {
32+
return nil, err
33+
}
34+
scope := strings.Split(req.FormValue("scope"), " ")[0]
35+
userinfo := ""
36+
if upn := req.FormValue("username"); upn != "" {
37+
clientinfo := base64.RawStdEncoding.EncodeToString([]byte(fmt.Sprintf(`{"uid":"%s","utid":"utid"}`, upn)))
38+
userinfo = fmt.Sprintf(`, "client_info":"%s", "id_token":"x.e30", "refresh_token": "rt"`, clientinfo)
39+
}
40+
res.Body = io.NopCloser(bytes.NewReader([]byte(fmt.Sprintf(`{"access_token": %q, "expires_in": 3600%s}`, scope, userinfo))))
41+
default:
42+
// User realm metadata request paths look like "/common/UserRealm/user@domain".
43+
// Matching on the UserRealm segment avoids having to know the UPN.
44+
if s[len(s)-2] == "UserRealm" {
45+
res.Body = io.NopCloser(
46+
strings.NewReader(`{"account_type":"Managed","cloud_audience_urn":"urn","cloud_instance_name":"...","domain_name":"..."}`),
47+
)
48+
} else {
49+
panic("unexpected request " + req.URL.String())
50+
}
51+
}
52+
return &res, nil
53+
}
54+
55+
func (m *mockSTS) CloseIdleConnections() {}
56+
57+
func instanceMetadata(tenant string) []byte {
58+
return []byte(strings.ReplaceAll(`{
59+
"tenant_discovery_endpoint": "https://login.microsoftonline.com/{tenant}/v2.0/.well-known/openid-configuration",
60+
"api-version": "1.1",
61+
"metadata": [
62+
{
63+
"preferred_network": "login.microsoftonline.com",
64+
"preferred_cache": "login.windows.net",
65+
"aliases": [
66+
"login.microsoftonline.com",
67+
"login.windows.net",
68+
"login.microsoft.com",
69+
"sts.windows.net"
70+
]
71+
}
72+
]
73+
}`, "{tenant}", tenant))
74+
}
75+
76+
func tenantMetadata(tenant string) []byte {
77+
return []byte(strings.ReplaceAll(`{
78+
"token_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token",
79+
"token_endpoint_auth_methods_supported": [
80+
"client_secret_post",
81+
"private_key_jwt",
82+
"client_secret_basic"
83+
],
84+
"jwks_uri": "https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys",
85+
"response_modes_supported": [
86+
"query",
87+
"fragment",
88+
"form_post"
89+
],
90+
"subject_types_supported": [
91+
"pairwise"
92+
],
93+
"id_token_signing_alg_values_supported": [
94+
"RS256"
95+
],
96+
"response_types_supported": [
97+
"code",
98+
"id_token",
99+
"code id_token",
100+
"id_token token"
101+
],
102+
"scopes_supported": [
103+
"openid",
104+
"profile",
105+
"email",
106+
"offline_access"
107+
],
108+
"issuer": "https://login.microsoftonline.com/{tenant}/v2.0",
109+
"request_uri_parameter_supported": false,
110+
"userinfo_endpoint": "https://graph.microsoft.com/oidc/userinfo",
111+
"authorization_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/authorize",
112+
"device_authorization_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/devicecode",
113+
"http_logout_supported": true,
114+
"frontchannel_logout_supported": true,
115+
"end_session_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/logout",
116+
"claims_supported": [
117+
"sub",
118+
"iss",
119+
"cloud_instance_name",
120+
"cloud_instance_host_name",
121+
"cloud_graph_host_name",
122+
"msgraph_host",
123+
"aud",
124+
"exp",
125+
"iat",
126+
"auth_time",
127+
"acr",
128+
"nonce",
129+
"preferred_username",
130+
"name",
131+
"tid",
132+
"ver",
133+
"at_hash",
134+
"c_hash",
135+
"email"
136+
],
137+
"kerberos_endpoint": "https://login.microsoftonline.com/{tenant}/kerberos",
138+
"tenant_region_scope": "NA",
139+
"cloud_instance_name": "microsoftonline.com",
140+
"cloud_graph_host_name": "graph.windows.net",
141+
"msgraph_host": "graph.microsoft.com",
142+
"rbac_url": "https://pas.windows.net"
143+
}`, "{tenant}", tenant))
144+
}

0 commit comments

Comments
 (0)