diff --git a/extensions/cache/accessor/accessor.go b/extensions/cache/accessor/accessor.go new file mode 100644 index 0000000..6aa08d4 --- /dev/null +++ b/extensions/cache/accessor/accessor.go @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +package accessor + +import "context" + +// Accessor accesses data storage. +type Accessor interface { + Read(context.Context) ([]byte, error) + Write(context.Context, []byte) error +} diff --git a/extensions/cache/accessor/file/file.go b/extensions/cache/accessor/file/file.go new file mode 100644 index 0000000..e9e3866 --- /dev/null +++ b/extensions/cache/accessor/file/file.go @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +package file + +import ( + "context" + "errors" + "os" + "path/filepath" + "sync" + + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache/accessor" +) + +// Storage stores data in an unencrypted file. +type Storage struct { + m *sync.RWMutex + p string +} + +// New is the constructor for Storage. "p" is the path to the file in which to store data. +func New(p string) (*Storage, error) { + return &Storage{m: &sync.RWMutex{}, p: p}, nil +} + +// Read returns the file's content or, if the file doesn't exist, a nil slice and error. +func (s *Storage) Read(context.Context) ([]byte, error) { + s.m.RLock() + defer s.m.RUnlock() + b, err := os.ReadFile(s.p) + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + return b, err +} + +// Write stores data in the file, overwriting any content, and creates the file if necessary. +func (s *Storage) Write(ctx context.Context, data []byte) error { + s.m.Lock() + defer s.m.Unlock() + err := os.WriteFile(s.p, data, 0600) + if errors.Is(err, os.ErrNotExist) { + dir := filepath.Dir(s.p) + if err = os.MkdirAll(dir, 0700); err == nil { + err = os.WriteFile(s.p, data, 0600) + } + } + return err +} + +var _ accessor.Accessor = (*Storage)(nil) diff --git a/extensions/cache/accessor/file/file_test.go b/extensions/cache/accessor/file/file_test.go new file mode 100644 index 0000000..76c5cee --- /dev/null +++ b/extensions/cache/accessor/file/file_test.go @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +package file + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +var ctx = context.Background() + +func TestRead(t *testing.T) { + p := filepath.Join(t.TempDir(), t.Name()) + a, err := New(p) + require.NoError(t, err) + + expected := []byte("expected") + require.NoError(t, os.WriteFile(p, expected, 0600)) + + actual, err := a.Read(ctx) + require.NoError(t, err) + require.Equal(t, expected, actual) +} + +func TestRoundTrip(t *testing.T) { + p := filepath.Join(t.TempDir(), "nonexistent", t.Name()) + a, err := New(p) + require.NoError(t, err) + + var expected []byte + for i := 0; i < 4; i++ { + actual, err := a.Read(ctx) + require.NoError(t, err) + require.Equal(t, expected, actual) + + expected = append(expected, byte(i)) + require.NoError(t, a.Write(ctx, expected)) + } +} + +func TestWrite(t *testing.T) { + p := filepath.Join(t.TempDir(), t.Name()) + for _, create := range []bool{true, false} { + name := "file exists" + if create { + name = "new file" + } + t.Run(name, func(t *testing.T) { + if create { + f, err := os.OpenFile(p, os.O_CREATE|os.O_EXCL, 0600) + require.NoError(t, err) + require.NoError(t, f.Close()) + } + a, err := New(p) + require.NoError(t, err) + + expected := []byte("expected") + require.NoError(t, a.Write(ctx, expected)) + + actual, err := os.ReadFile(p) + require.NoError(t, err) + require.Equal(t, expected, actual) + }) + } +} diff --git a/extensions/cache/cache.go b/extensions/cache/cache.go new file mode 100644 index 0000000..c6626d4 --- /dev/null +++ b/extensions/cache/cache.go @@ -0,0 +1,154 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +package cache + +import ( + "context" + "errors" + "os" + "path/filepath" + "sync" + "time" + + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache/accessor" + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/internal/lock" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" +) + +var ( + // retryDelay lets tests prevent delays when faking errors in Replace + retryDelay = 10 * time.Millisecond + // timeout lets tests set the default amount of time allowed to read from the accessor + timeout = time.Second +) + +// locker helps tests fake Lock +type locker interface { + Lock(context.Context) error + Unlock() error +} + +// Cache caches authentication data in external storage, using a file lock to coordinate +// access to it with other processes. +type Cache struct { + // a provides read/write access to storage + a accessor.Accessor + // data is accessor's data as of the last sync + data []byte + // l coordinates with other processes + l locker + // m coordinates this process's goroutines + m *sync.Mutex + // sync is when this Cache last read from or wrote to a + sync time.Time + // ts is the path to a file used to timestamp Export and Replace operations + ts string +} + +// New is the constructor for Cache. "p" is the path to a file used to track when stored +// data changes. Export will create this file and any directories in its path which don't +// already exist. +func New(a accessor.Accessor, p string) (*Cache, error) { + lock, err := lock.New(p+".lockfile", retryDelay) + if err != nil { + return nil, err + } + return &Cache{a: a, l: lock, m: &sync.Mutex{}, ts: p}, err +} + +// Export writes the bytes marshaled by "m" to the accessor. +// MSAL clients call this method automatically. +func (c *Cache) Export(ctx context.Context, m cache.Marshaler, h cache.ExportHints) (err error) { + c.m.Lock() + defer c.m.Unlock() + + data, err := m.Marshal() + if err != nil { + return err + } + err = c.l.Lock(ctx) + if err != nil { + return err + } + defer func() { + e := c.l.Unlock() + if err == nil { + err = e + } + }() + if err = c.a.Write(ctx, data); err == nil { + // touch the timestamp file to record the time of this write; discard any + // error because this is just an optimization to avoid redundant reads + c.sync = time.Now() + if er := os.Chtimes(c.ts, c.sync, c.sync); errors.Is(er, os.ErrNotExist) { + if er = os.MkdirAll(filepath.Dir(c.ts), 0700); er == nil { + f, _ := os.OpenFile(c.ts, os.O_CREATE, 0600) + _ = f.Close() + } + } + c.data = data + } + return err +} + +// Replace reads bytes from the accessor and unmarshals them to "u". +// MSAL clients call this method automatically. +func (c *Cache) Replace(ctx context.Context, u cache.Unmarshaler, h cache.ReplaceHints) error { + c.m.Lock() + defer c.m.Unlock() + + // If the timestamp file indicates cached data hasn't changed since we last read or wrote it, + // return c.data, which is the data as of that time. Discard any error from reading the timestamp + // because this is just an optimization to prevent unnecessary reads. If we don't know whether + // cached data has changed, we assume it has. + read := true + data := c.data + f, err := os.Stat(c.ts) + if err == nil { + mt := f.ModTime() + read = !mt.Equal(c.sync) + } + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + // Unmarshal the accessor's data, reading it first if needed. We don't acquire the file lock before + // reading from the accessor because it isn't strictly necessary and is relatively expensive. In the + // unlikely event that a read overlaps with a write and returns malformed data, Unmarshal will return + // an error and we'll try another read. + for { + if read { + data, err = c.a.Read(ctx) + if err != nil { + break + } + } + err = u.Unmarshal(data) + if err == nil { + break + } else if !read { + // c.data is apparently corrupt; Read from the accessor before trying again + read = true + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(retryDelay): + // Unmarshal error; try again + } + } + // Update the sync time only if we read from the accessor and unmarshaled its data. Otherwise + // the data hasn't changed since the last read/write, or reading failed and we'll try again on + // the next call. + if err == nil && read { + c.data = data + if f, err := os.Stat(c.ts); err == nil { + c.sync = f.ModTime() + } + } + return err +} + +var _ cache.ExportReplace = (*Cache)(nil) diff --git a/extensions/cache/cache_test.go b/extensions/cache/cache_test.go new file mode 100644 index 0000000..e4d0113 --- /dev/null +++ b/extensions/cache/cache_test.go @@ -0,0 +1,291 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +package cache + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "runtime" + "sync" + "testing" + "time" + + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" + "github.com/stretchr/testify/require" +) + +var ctx = context.Background() + +// fakeExternalCache implements accessor.Accessor to fake a persistent cache +type fakeExternalCache struct { + data []byte + readCallback, writeCallback func() error +} + +func (a *fakeExternalCache) Read(context.Context) ([]byte, error) { + var err error + if a.readCallback != nil { + err = a.readCallback() + } + return a.data, err +} + +func (a *fakeExternalCache) Write(ctx context.Context, b []byte) error { + var err error + if a.writeCallback != nil { + err = a.writeCallback() + } + if err != nil { + return err + } + cp := make([]byte, len(b)) + copy(cp, b) + a.data = cp + return nil +} + +// fakeInternalCache implements cache.Un/Marshaler to fake an MSAL client's in-memory cache +type fakeInternalCache struct { + data []byte + marshalCallback, unmarshalCallback func() error +} + +func (t *fakeInternalCache) Marshal() ([]byte, error) { + var err error + if t.marshalCallback != nil { + err = t.marshalCallback() + } + return t.data, err +} + +func (t *fakeInternalCache) Unmarshal(b []byte) error { + var err error + if t.unmarshalCallback != nil { + err = t.unmarshalCallback() + } + cp := make([]byte, len(b)) + copy(cp, b) + t.data = cp + return err +} + +type fakeLock struct { + lockErr, unlockErr error +} + +func (l fakeLock) Lock(context.Context) error { + return l.lockErr +} + +func (l fakeLock) Unlock() error { + return l.unlockErr +} + +func TestExport(t *testing.T) { + ec := &fakeExternalCache{} + ic := &fakeInternalCache{} + p := filepath.Join(t.TempDir(), t.Name(), "ts") + c, err := New(ec, p) + require.NoError(t, err) + + // Export should write the in-memory cache to the accessor and touch the timestamp file + lastWrite := time.Time{} + touched := false + for i := 0; i < 3; i++ { + s := fmt.Sprint(i) + *ic = fakeInternalCache{data: []byte(s)} + err = c.Export(ctx, ic, cache.ExportHints{}) + require.NoError(t, err) + require.Equal(t, []byte(s), ec.data) + + f, err := os.Stat(p) + require.NoError(t, err) + mt := f.ModTime() + + // Two iterations of this loop can run within one unit of system time on Windows, leaving the + // modtime apparently unchanged even though Export updated it. On Windows we therefore skip + // the strict test, instead requiring only that the modtime change once during this loop. + if runtime.GOOS != "windows" { + require.NotEqual(t, lastWrite, mt, "Export didn't update the timestamp") + } + if mt != lastWrite { + touched = true + } + lastWrite = mt + } + require.True(t, touched, "Export didn't update the timestamp") +} + +func TestFilenameCompat(t *testing.T) { + // verify Cache uses the same lock file name as would e.g. the Python implementation + p := filepath.Join(t.TempDir(), t.Name()) + ec := fakeExternalCache{ + // Cache should hold the file lock while calling Write + writeCallback: func() error { + require.FileExists(t, p+".lockfile", "missing expected lock file") + return nil + }, + } + c, err := New(&ec, p) + require.NoError(t, err) + + err = c.Export(ctx, &fakeInternalCache{}, cache.ExportHints{}) + require.NoError(t, err) +} + +func TestLockError(t *testing.T) { + c, err := New(&fakeExternalCache{}, filepath.Join(t.TempDir(), t.Name())) + require.NoError(t, err) + expected := errors.New("expected") + c.l = fakeLock{lockErr: expected} + err = c.Export(ctx, &fakeInternalCache{}, cache.ExportHints{}) + require.EqualError(t, err, expected.Error()) +} + +func TestPreservesTimestampFileContent(t *testing.T) { + p := filepath.Join(t.TempDir(), t.Name()) + expected := []byte("expected") + err := os.WriteFile(p, expected, 0600) + require.NoError(t, err) + + ec := fakeExternalCache{} + c, err := New(&ec, p) + require.NoError(t, err) + + ic := fakeInternalCache{data: []byte("data")} + err = c.Export(ctx, &ic, cache.ExportHints{}) + require.NoError(t, err) + require.Equal(t, ic.data, ec.data) + + actual, err := os.ReadFile(p) + require.NoError(t, err) + require.Equal(t, expected, actual, "Cache truncated, or wrote to, the timestamp file") +} + +func TestRace(t *testing.T) { + ic := fakeInternalCache{} + ec := fakeExternalCache{} + c, err := New(&ec, filepath.Join(t.TempDir(), t.Name())) + require.NoError(t, err) + + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + if !t.Failed() { + err := c.Replace(ctx, &ic, cache.ReplaceHints{}) + if err == nil { + err = c.Export(ctx, &ic, cache.ExportHints{}) + } + if err != nil { + t.Errorf("%d: %s", i, err) + } + } + }(i) + } + wg.Wait() +} + +func TestReplace(t *testing.T) { + ic := fakeInternalCache{} + ec := fakeExternalCache{} + p := filepath.Join(t.TempDir(), t.Name()) + + c, err := New(&ec, p) + require.NoError(t, err) + require.Empty(t, ic) + + // Replace should read data from the accessor (external cache) into the in-memory cache, observing the timestamp file + f, err := os.Create(p) + require.NoError(t, err) + require.NoError(t, f.Close()) + for i := uint8(0); i < 4; i++ { + ec.data = []byte{i} + err = c.Replace(ctx, &ic, cache.ReplaceHints{}) + require.NoError(t, err) + require.EqualValues(t, ec.data, ic.data) + // touch the timestamp file to indicate another accessor wrote data. Backdating ensures the + // timestamp changes between iterations even when one executes faster than file time resolution + tm := time.Now().Add(-time.Duration(i+1) * time.Second) + require.NoError(t, os.Chtimes(p, tm, tm)) + } + + // Replace should return in-memory data when the timestamp indicates no intervening write to the persistent cache + for i := 0; i < 4; i++ { + err = c.Replace(ctx, &ic, cache.ReplaceHints{}) + require.NoError(t, err) + // ec.data hasn't changed; ic.data shouldn't change either + require.EqualValues(t, ec.data, ic.data) + } +} + +func TestReplaceErrors(t *testing.T) { + realDelay := retryDelay + retryDelay = 0 + t.Cleanup(func() { retryDelay = realDelay }) + expected := errors.New("expected") + + t.Run("read", func(t *testing.T) { + ec := &fakeExternalCache{readCallback: func() error { + return expected + }} + p := filepath.Join(t.TempDir(), t.Name()) + c, err := New(ec, p) + require.NoError(t, err) + + err = c.Replace(ctx, &fakeInternalCache{}, cache.ReplaceHints{}) + require.Equal(t, expected, err) + }) + + for _, transient := range []bool{true, false} { + name := "unmarshal error" + if transient { + name = "transient " + name + } + t.Run(name, func(t *testing.T) { + tries := 0 + ic := fakeInternalCache{unmarshalCallback: func() error { + tries++ + if transient && tries > 1 { + return nil + } + return expected + }} + ec := &fakeExternalCache{} + + p := filepath.Join(t.TempDir(), t.Name()) + c, err := New(ec, p) + require.NoError(t, err) + + cx, cancel := context.WithTimeout(ctx, time.Millisecond) + defer cancel() + err = c.Replace(cx, &ic, cache.ReplaceHints{}) + // err should be nil if the unmarshaling error was transient, non-nil if it wasn't + require.Equal(t, transient, err == nil) + }) + } +} + +func TestUnlockError(t *testing.T) { + p := filepath.Join(t.TempDir(), t.Name()) + a := fakeExternalCache{} + c, err := New(&a, p) + require.NoError(t, err) + + // Export should return an error from Unlock()... + unlockErr := errors.New("unlock error") + c.l = fakeLock{unlockErr: unlockErr} + err = c.Export(ctx, &fakeInternalCache{}, cache.ExportHints{}) + require.Equal(t, unlockErr, err) + + // ...unless another of its calls returned an error + writeErr := errors.New("write error") + a.writeCallback = func() error { return writeErr } + err = c.Export(ctx, &fakeInternalCache{}, cache.ExportHints{}) + require.Equal(t, writeErr, err) +} diff --git a/extensions/integration_test.go b/extensions/integration_test.go new file mode 100644 index 0000000..a1109ad --- /dev/null +++ b/extensions/integration_test.go @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +package extensions + +import ( + "context" + "fmt" + "net/http" + "path/filepath" + "strings" + "sync" + "testing" + + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache" + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache/accessor/file" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/stretchr/testify/require" +) + +var ctx = context.Background() + +func TestConfidentialClient(t *testing.T) { + t.Parallel() + p := filepath.Join(t.TempDir(), t.Name()) + a, err := file.New(p) + require.NoError(t, err) + c, err := cache.New(a, p+".timestamp") + require.NoError(t, err) + cred, err := confidential.NewCredFromSecret("*") + require.NoError(t, err) + client, err := confidential.New( + "https://login.microsoftonline.com/tenant", "clientID", cred, confidential.WithCache(c), confidential.WithHTTPClient(&mockSTS{}), + ) + require.NoError(t, err) + + gr := 20 + wg := sync.WaitGroup{} + for i := 0; i < gr; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + if t.Failed() { + return + } + s := fmt.Sprint(n) + ar, err := client.AcquireTokenByCredential(ctx, []string{s}) + switch { + case err != nil: + t.Error(err) + case ar.AccessToken != s: + t.Errorf("possible test bug: expected %q from STS, got %q", s, ar.AccessToken) + default: + ar, err = client.AcquireTokenSilent(ctx, []string{s}) + if err != nil { + t.Error(err) + } else if ar.AccessToken != s { + t.Errorf("possible cache corruption: expected %q, got %q", s, ar.AccessToken) + } + } + }(i) + } + wg.Wait() + if t.Failed() { + return + } + + // cache should have an access token from each goroutine + lost := gr + for i := 0; i < gr; i++ { + s := fmt.Sprint(i) + ar, err := client.AcquireTokenSilent(ctx, []string{s}) + if err == nil { + lost-- + if ar.AccessToken != s { + t.Errorf("possible cache corruption: expected %q, got %q", s, ar.AccessToken) + } + } + } + require.Equal(t, 0, lost, "lost %d/%d tokens", lost, gr) +} + +func TestPublicClient(t *testing.T) { + t.Parallel() + p := filepath.Join(t.TempDir(), t.Name()) + a, err := file.New(p) + require.NoError(t, err) + c, err := cache.New(a, p+".timestamp") + require.NoError(t, err) + sts := mockSTS{} + client, err := public.New("clientID", public.WithCache(c), public.WithHTTPClient(&sts)) + require.NoError(t, err) + + gr := 20 + wg := sync.WaitGroup{} + for i := 0; i < gr; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + if t.Failed() { + return + } + s := fmt.Sprint(n) + ar, err := client.AcquireTokenByUsernamePassword(ctx, []string{s}, s, "password") + switch { + case err != nil: + t.Error(err) + case ar.AccessToken != s: + t.Errorf("possible test bug: expected %q from STS, got %q", s, ar.AccessToken) + default: + ar, err = client.AcquireTokenSilent(ctx, []string{s}, public.WithSilentAccount(ar.Account)) + if err != nil { + t.Error(err) + } else if ar.AccessToken != s { + t.Errorf("possible cache corruption: expected %q, got %q", s, ar.AccessToken) + } + } + }(i) + } + wg.Wait() + if t.Failed() { + return + } + + accounts, err := client.Accounts(ctx) + require.NoError(t, err) + require.Equal(t, gr, len(accounts), "should have a cached account for each goroutine") + + // Verify no access token cached above was lost due to a race. Silent auth should return a cached + // access token given any scope above. A token request during this loop indicates the client + // exchanged a refresh token to reacquire the access token it should have found in the cache. + lostATs, reqs := 0, 0 + sts.tokenRequestCallback = func(*http.Request) { reqs++ } + for _, a := range accounts { + s, _, found := strings.Cut(a.HomeAccountID, ".") + require.True(t, found, "unexpected home account ID %q", a.HomeAccountID) + ar, err := client.AcquireTokenSilent(ctx, []string{s}, public.WithSilentAccount(a)) + if err != nil { + // the cache has no access token for the expected scope and no refresh token for the account + lostATs++ + } else if ar.AccessToken != s { + t.Errorf("possible cache corruption: expected %q, got %q", s, ar.AccessToken) + } + } + require.Equal(t, 0, lostATs+reqs, "lost %d/%d access tokens", reqs, gr) + + // The cache has all the expected access tokens but may have lost refresh tokens, so we try silent + // auth again for each account, passing a new scope to force the client to use a refresh token. + lostRTs := 0 + for _, a := range accounts { + s := "novelscope" + ar, err := client.AcquireTokenSilent(ctx, []string{s}, public.WithSilentAccount(a)) + if err != nil { + lostRTs++ + } else if ar.AccessToken != s { + t.Errorf("possible cache corruption: expected %q, got %q", s, ar.AccessToken) + } + } + require.Equal(t, 0, lostRTs, "lost %d/%d refresh tokens", lostRTs, gr) +} diff --git a/extensions/mock_test.go b/extensions/mock_test.go new file mode 100644 index 0000000..9385c3b --- /dev/null +++ b/extensions/mock_test.go @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +package extensions + +import ( + "bytes" + "encoding/base64" + "fmt" + "io" + "net/http" + "strings" +) + +// mockSTS returns mock Azure AD responses so tests don't have to account for MSAL metadata requests +type mockSTS struct { + tokenRequestCallback func(*http.Request) +} + +func (m *mockSTS) Do(req *http.Request) (*http.Response, error) { + res := http.Response{StatusCode: http.StatusOK} + switch s := strings.Split(req.URL.Path, "/"); s[len(s)-1] { + case "instance": + res.Body = io.NopCloser(bytes.NewReader(instanceMetadata("tenant"))) + case "openid-configuration": + res.Body = io.NopCloser(bytes.NewReader(tenantMetadata("tenant"))) + case "token": + if m.tokenRequestCallback != nil { + m.tokenRequestCallback(req) + } + if err := req.ParseForm(); err != nil { + return nil, err + } + scope := strings.Split(req.FormValue("scope"), " ")[0] + userinfo := "" + if upn := req.FormValue("username"); upn != "" { + clientinfo := base64.RawStdEncoding.EncodeToString([]byte(fmt.Sprintf(`{"uid":"%s","utid":"utid"}`, upn))) + userinfo = fmt.Sprintf(`, "client_info":"%s", "id_token":"x.e30", "refresh_token": "rt"`, clientinfo) + } + res.Body = io.NopCloser(bytes.NewReader([]byte(fmt.Sprintf(`{"access_token": %q, "expires_in": 3600%s}`, scope, userinfo)))) + default: + // User realm metadata request paths look like "/common/UserRealm/user@domain". + // Matching on the UserRealm segment avoids having to know the UPN. + if s[len(s)-2] == "UserRealm" { + res.Body = io.NopCloser( + strings.NewReader(`{"account_type":"Managed","cloud_audience_urn":"urn","cloud_instance_name":"...","domain_name":"..."}`), + ) + } else { + panic("unexpected request " + req.URL.String()) + } + } + return &res, nil +} + +func (m *mockSTS) CloseIdleConnections() {} + +func instanceMetadata(tenant string) []byte { + return []byte(strings.ReplaceAll(`{ + "tenant_discovery_endpoint": "https://login.microsoftonline.com/{tenant}/v2.0/.well-known/openid-configuration", + "api-version": "1.1", + "metadata": [ + { + "preferred_network": "login.microsoftonline.com", + "preferred_cache": "login.windows.net", + "aliases": [ + "login.microsoftonline.com", + "login.windows.net", + "login.microsoft.com", + "sts.windows.net" + ] + } + ] + }`, "{tenant}", tenant)) +} + +func tenantMetadata(tenant string) []byte { + return []byte(strings.ReplaceAll(`{ + "token_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token", + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "private_key_jwt", + "client_secret_basic" + ], + "jwks_uri": "https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys", + "response_modes_supported": [ + "query", + "fragment", + "form_post" + ], + "subject_types_supported": [ + "pairwise" + ], + "id_token_signing_alg_values_supported": [ + "RS256" + ], + "response_types_supported": [ + "code", + "id_token", + "code id_token", + "id_token token" + ], + "scopes_supported": [ + "openid", + "profile", + "email", + "offline_access" + ], + "issuer": "https://login.microsoftonline.com/{tenant}/v2.0", + "request_uri_parameter_supported": false, + "userinfo_endpoint": "https://graph.microsoft.com/oidc/userinfo", + "authorization_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/authorize", + "device_authorization_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/devicecode", + "http_logout_supported": true, + "frontchannel_logout_supported": true, + "end_session_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/logout", + "claims_supported": [ + "sub", + "iss", + "cloud_instance_name", + "cloud_instance_host_name", + "cloud_graph_host_name", + "msgraph_host", + "aud", + "exp", + "iat", + "auth_time", + "acr", + "nonce", + "preferred_username", + "name", + "tid", + "ver", + "at_hash", + "c_hash", + "email" + ], + "kerberos_endpoint": "https://login.microsoftonline.com/{tenant}/kerberos", + "tenant_region_scope": "NA", + "cloud_instance_name": "microsoftonline.com", + "cloud_graph_host_name": "graph.windows.net", + "msgraph_host": "graph.microsoft.com", + "rbac_url": "https://pas.windows.net" + }`, "{tenant}", tenant)) +} diff --git a/extensions/perf_test.go b/extensions/perf_test.go new file mode 100644 index 0000000..bf4cc86 --- /dev/null +++ b/extensions/perf_test.go @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +package extensions + +import ( + "fmt" + "path/filepath" + "sync" + "testing" + + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache" + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache/accessor/file" + mcache "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/stretchr/testify/require" +) + +// this file benchmarks MSAL clients using Cache and file.Accessor + +func newCache(b *testing.B) *cache.Cache { + p := filepath.Join(b.TempDir(), b.Name()) + a, err := file.New(p) + require.NoError(b, err) + c, err := cache.New(a, p+".timestamp") + require.NoError(b, err) + return c +} + +func BenchmarkConfidentialClient(b *testing.B) { + for _, baseline := range []bool{false, true} { + name := "file accessor" + if baseline { + name = "no persistence" + } + b.Run(name, func(b *testing.B) { + var c mcache.ExportReplace + if !baseline { + c = newCache(b) + } + cred, err := confidential.NewCredFromSecret("*") + require.NoError(b, err) + client, err := confidential.New( + "https://login.microsoftonline.com/tenant", "ID", cred, confidential.WithCache(c), confidential.WithHTTPClient(&mockSTS{}), + ) + require.NoError(b, err) + + gr := 10 + wg := sync.WaitGroup{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + for i := 0; i < gr; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + s := fmt.Sprint(n) + _, _ = client.AcquireTokenByCredential(ctx, []string{s}) + _, _ = client.AcquireTokenSilent(ctx, []string{s}) + }(i) + } + wg.Wait() + } + }) + } +} + +func BenchmarkPublicClient(b *testing.B) { + for _, baseline := range []bool{false, true} { + name := "file accessor" + if baseline { + name = "no persistence" + } + b.Run(name, func(b *testing.B) { + var c mcache.ExportReplace + if !baseline { + c = newCache(b) + } + client, err := public.New("clientID", public.WithCache(c), public.WithHTTPClient(&mockSTS{})) + require.NoError(b, err) + + gr := 10 + wg := sync.WaitGroup{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + for i := 0; i < gr; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + s := fmt.Sprint(n) + ar, _ := client.AcquireTokenByUsernamePassword(ctx, []string{s}, s, "password") + _, _ = client.AcquireTokenSilent(ctx, []string{s}, public.WithSilentAccount(ar.Account)) + }(i) + } + wg.Wait() + } + }) + } +} diff --git a/go.mod b/go.mod index ed4d424..b5b0c11 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,20 @@ module github.com/AzureAD/microsoft-authentication-extensions-for-go go 1.18 require ( + github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 github.com/stretchr/testify v1.8.2 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/golang-jwt/jwt/v4 v4.4.3 // indirect + github.com/google/uuid v1.3.0 // indirect github.com/kr/pretty v0.2.1 // indirect github.com/kr/text v0.1.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sys v0.5.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ab0a0eb..1bcda28 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,21 @@ +github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 h1:OBhqkivkhkMqLPymWEppkm7vgPQY2XsHoEkaMQ0AdZY= +github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0/go.mod h1:kgDmCTgBzIEPFElEF+FK0SdjAor06dRq2Go927dnQ6o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU= +github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= +github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -15,6 +25,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=