From 61f44614bb4c0b2a5670ef6290c05ac146f1a9b2 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Wed, 3 May 2023 13:28:07 -0700 Subject: [PATCH 1/7] require MSAL --- go.mod | 6 ++++++ go.sum | 13 +++++++++++++ 2 files changed, 19 insertions(+) diff --git a/go.mod b/go.mod index ed4d424..b5b0c11 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,20 @@ module github.com/AzureAD/microsoft-authentication-extensions-for-go go 1.18 require ( + github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 github.com/stretchr/testify v1.8.2 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/golang-jwt/jwt/v4 v4.4.3 // indirect + github.com/google/uuid v1.3.0 // indirect github.com/kr/pretty v0.2.1 // indirect github.com/kr/text v0.1.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sys v0.5.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ab0a0eb..1bcda28 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,21 @@ +github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 h1:OBhqkivkhkMqLPymWEppkm7vgPQY2XsHoEkaMQ0AdZY= +github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0/go.mod h1:kgDmCTgBzIEPFElEF+FK0SdjAor06dRq2Go927dnQ6o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU= +github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= +github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -15,6 +25,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= From e4d150014ce1a2f61e32f1475a468ac06b21a269 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 20 Apr 2023 12:23:14 -0700 Subject: [PATCH 2/7] Accessor reads/writes storage --- extensions/accessor/accessor.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 extensions/accessor/accessor.go diff --git a/extensions/accessor/accessor.go b/extensions/accessor/accessor.go new file mode 100644 index 0000000..6aa08d4 --- /dev/null +++ b/extensions/accessor/accessor.go @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +package accessor + +import "context" + +// Accessor accesses data storage. +type Accessor interface { + Read(context.Context) ([]byte, error) + Write(context.Context, []byte) error +} From 991806ded0e9207c1e41f6f63d5b068e451422ba Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 20 Apr 2023 13:46:55 -0700 Subject: [PATCH 3/7] file storage --- extensions/accessor/file/file.go | 52 ++++++++++++++++++++ extensions/accessor/file/file_test.go | 70 +++++++++++++++++++++++++++ 2 files changed, 122 insertions(+) create mode 100644 extensions/accessor/file/file.go create mode 100644 extensions/accessor/file/file_test.go diff --git a/extensions/accessor/file/file.go b/extensions/accessor/file/file.go new file mode 100644 index 0000000..a57c814 --- /dev/null +++ b/extensions/accessor/file/file.go @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +package file + +import ( + "context" + "errors" + "os" + "path/filepath" + "sync" + + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/accessor" +) + +// Storage stores data in an unencrypted file. +type Storage struct { + m *sync.RWMutex + p string +} + +// New is the constructor for Storage. "p" is the path to the file in which to store data. +func New(p string) (*Storage, error) { + return &Storage{m: &sync.RWMutex{}, p: p}, nil +} + +// Read returns the file's content or, if the file doesn't exist, a nil slice and error. +func (s *Storage) Read(context.Context) ([]byte, error) { + s.m.RLock() + defer s.m.RUnlock() + b, err := os.ReadFile(s.p) + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + return b, err +} + +// Write stores data in the file, overwriting any content, and creates the file if necessary. +func (s *Storage) Write(ctx context.Context, data []byte) error { + s.m.Lock() + defer s.m.Unlock() + err := os.WriteFile(s.p, data, 0600) + if errors.Is(err, os.ErrNotExist) { + dir := filepath.Dir(s.p) + if err = os.MkdirAll(dir, 0700); err == nil { + err = os.WriteFile(s.p, data, 0600) + } + } + return err +} + +var _ accessor.Accessor = (*Storage)(nil) diff --git a/extensions/accessor/file/file_test.go b/extensions/accessor/file/file_test.go new file mode 100644 index 0000000..76c5cee --- /dev/null +++ b/extensions/accessor/file/file_test.go @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +package file + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +var ctx = context.Background() + +func TestRead(t *testing.T) { + p := filepath.Join(t.TempDir(), t.Name()) + a, err := New(p) + require.NoError(t, err) + + expected := []byte("expected") + require.NoError(t, os.WriteFile(p, expected, 0600)) + + actual, err := a.Read(ctx) + require.NoError(t, err) + require.Equal(t, expected, actual) +} + +func TestRoundTrip(t *testing.T) { + p := filepath.Join(t.TempDir(), "nonexistent", t.Name()) + a, err := New(p) + require.NoError(t, err) + + var expected []byte + for i := 0; i < 4; i++ { + actual, err := a.Read(ctx) + require.NoError(t, err) + require.Equal(t, expected, actual) + + expected = append(expected, byte(i)) + require.NoError(t, a.Write(ctx, expected)) + } +} + +func TestWrite(t *testing.T) { + p := filepath.Join(t.TempDir(), t.Name()) + for _, create := range []bool{true, false} { + name := "file exists" + if create { + name = "new file" + } + t.Run(name, func(t *testing.T) { + if create { + f, err := os.OpenFile(p, os.O_CREATE|os.O_EXCL, 0600) + require.NoError(t, err) + require.NoError(t, f.Close()) + } + a, err := New(p) + require.NoError(t, err) + + expected := []byte("expected") + require.NoError(t, a.Write(ctx, expected)) + + actual, err := os.ReadFile(p) + require.NoError(t, err) + require.Equal(t, expected, actual) + }) + } +} From 7a05cec9376e234f33620994e93b0d98779aeb4f Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Mon, 1 May 2023 14:44:03 -0700 Subject: [PATCH 4/7] cache package --- extensions/cache/cache.go | 154 +++++++++++++++++ extensions/cache/cache_test.go | 291 +++++++++++++++++++++++++++++++++ 2 files changed, 445 insertions(+) create mode 100644 extensions/cache/cache.go create mode 100644 extensions/cache/cache_test.go diff --git a/extensions/cache/cache.go b/extensions/cache/cache.go new file mode 100644 index 0000000..4b9a524 --- /dev/null +++ b/extensions/cache/cache.go @@ -0,0 +1,154 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +package cache + +import ( + "context" + "errors" + "os" + "path/filepath" + "sync" + "time" + + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/accessor" + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/internal/lock" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" +) + +var ( + // retryDelay lets tests prevent delays when faking errors in Replace + retryDelay = 10 * time.Millisecond + // timeout lets tests set the default amount of time allowed to read from the accessor + timeout = time.Second +) + +// locker helps tests fake Lock +type locker interface { + Lock(context.Context) error + Unlock() error +} + +// Cache caches authentication data in external storage, using a file lock to coordinate +// access to it with other processes. +type Cache struct { + // a provides read/write access to storage + a accessor.Accessor + // data is accessor's data as of the last sync + data []byte + // l coordinates with other processes + l locker + // m coordinates this process's goroutines + m *sync.Mutex + // sync is when this Cache last read from or wrote to a + sync time.Time + // ts is the path to a file used to timestamp Export and Replace operations + ts string +} + +// New is the constructor for Cache. "p" is the path to a file used to track when stored +// data changes. Export will create this file and any directories in its path which don't +// already exist. +func New(a accessor.Accessor, p string) (*Cache, error) { + lock, err := lock.New(p+".lockfile", retryDelay) + if err != nil { + return nil, err + } + return &Cache{a: a, l: lock, m: &sync.Mutex{}, ts: p}, err +} + +// Export writes the bytes marshaled by "m" to the accessor. +// MSAL clients call this method automatically. +func (c *Cache) Export(ctx context.Context, m cache.Marshaler, h cache.ExportHints) (err error) { + c.m.Lock() + defer c.m.Unlock() + + data, err := m.Marshal() + if err != nil { + return err + } + err = c.l.Lock(ctx) + if err != nil { + return err + } + defer func() { + e := c.l.Unlock() + if err == nil { + err = e + } + }() + if err = c.a.Write(ctx, data); err == nil { + // touch the timestamp file to record the time of this write; discard any + // error because this is just an optimization to avoid redundant reads + c.sync = time.Now() + if er := os.Chtimes(c.ts, c.sync, c.sync); errors.Is(er, os.ErrNotExist) { + if er = os.MkdirAll(filepath.Dir(c.ts), 0700); er == nil { + f, _ := os.OpenFile(c.ts, os.O_CREATE, 0600) + _ = f.Close() + } + } + c.data = data + } + return err +} + +// Replace reads bytes from the accessor and unmarshals them to "u". +// MSAL clients call this method automatically. +func (c *Cache) Replace(ctx context.Context, u cache.Unmarshaler, h cache.ReplaceHints) error { + c.m.Lock() + defer c.m.Unlock() + + // If the timestamp file indicates cached data hasn't changed since we last read or wrote it, + // return c.data, which is the data as of that time. Discard any error from reading the timestamp + // because this is just an optimization to prevent unnecessary reads. If we don't know whether + // cached data has changed, we assume it has. + read := true + data := c.data + f, err := os.Stat(c.ts) + if err == nil { + mt := f.ModTime() + read = !mt.Equal(c.sync) + } + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + // Unmarshal the accessor's data, reading it first if needed. We don't acquire the file lock before + // reading from the accessor because it isn't strictly necessary and is relatively expensive. In the + // unlikely event that a read overlaps with a write and returns malformed data, Unmarshal will return + // an error and we'll try another read. + for { + if read { + data, err = c.a.Read(ctx) + if err != nil { + break + } + } + err = u.Unmarshal(data) + if err == nil { + break + } else if !read { + // c.data is apparently corrupt; Read from the accessor before trying again + read = true + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(retryDelay): + // Unmarshal error; try again + } + } + // Update the sync time only if we read from the accessor and unmarshaled its data. Otherwise + // the data hasn't changed since the last read/write, or reading failed and we'll try again on + // the next call. + if err == nil && read { + c.data = data + if f, err := os.Stat(c.ts); err == nil { + c.sync = f.ModTime() + } + } + return err +} + +var _ cache.ExportReplace = (*Cache)(nil) diff --git a/extensions/cache/cache_test.go b/extensions/cache/cache_test.go new file mode 100644 index 0000000..e4d0113 --- /dev/null +++ b/extensions/cache/cache_test.go @@ -0,0 +1,291 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +package cache + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "runtime" + "sync" + "testing" + "time" + + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" + "github.com/stretchr/testify/require" +) + +var ctx = context.Background() + +// fakeExternalCache implements accessor.Accessor to fake a persistent cache +type fakeExternalCache struct { + data []byte + readCallback, writeCallback func() error +} + +func (a *fakeExternalCache) Read(context.Context) ([]byte, error) { + var err error + if a.readCallback != nil { + err = a.readCallback() + } + return a.data, err +} + +func (a *fakeExternalCache) Write(ctx context.Context, b []byte) error { + var err error + if a.writeCallback != nil { + err = a.writeCallback() + } + if err != nil { + return err + } + cp := make([]byte, len(b)) + copy(cp, b) + a.data = cp + return nil +} + +// fakeInternalCache implements cache.Un/Marshaler to fake an MSAL client's in-memory cache +type fakeInternalCache struct { + data []byte + marshalCallback, unmarshalCallback func() error +} + +func (t *fakeInternalCache) Marshal() ([]byte, error) { + var err error + if t.marshalCallback != nil { + err = t.marshalCallback() + } + return t.data, err +} + +func (t *fakeInternalCache) Unmarshal(b []byte) error { + var err error + if t.unmarshalCallback != nil { + err = t.unmarshalCallback() + } + cp := make([]byte, len(b)) + copy(cp, b) + t.data = cp + return err +} + +type fakeLock struct { + lockErr, unlockErr error +} + +func (l fakeLock) Lock(context.Context) error { + return l.lockErr +} + +func (l fakeLock) Unlock() error { + return l.unlockErr +} + +func TestExport(t *testing.T) { + ec := &fakeExternalCache{} + ic := &fakeInternalCache{} + p := filepath.Join(t.TempDir(), t.Name(), "ts") + c, err := New(ec, p) + require.NoError(t, err) + + // Export should write the in-memory cache to the accessor and touch the timestamp file + lastWrite := time.Time{} + touched := false + for i := 0; i < 3; i++ { + s := fmt.Sprint(i) + *ic = fakeInternalCache{data: []byte(s)} + err = c.Export(ctx, ic, cache.ExportHints{}) + require.NoError(t, err) + require.Equal(t, []byte(s), ec.data) + + f, err := os.Stat(p) + require.NoError(t, err) + mt := f.ModTime() + + // Two iterations of this loop can run within one unit of system time on Windows, leaving the + // modtime apparently unchanged even though Export updated it. On Windows we therefore skip + // the strict test, instead requiring only that the modtime change once during this loop. + if runtime.GOOS != "windows" { + require.NotEqual(t, lastWrite, mt, "Export didn't update the timestamp") + } + if mt != lastWrite { + touched = true + } + lastWrite = mt + } + require.True(t, touched, "Export didn't update the timestamp") +} + +func TestFilenameCompat(t *testing.T) { + // verify Cache uses the same lock file name as would e.g. the Python implementation + p := filepath.Join(t.TempDir(), t.Name()) + ec := fakeExternalCache{ + // Cache should hold the file lock while calling Write + writeCallback: func() error { + require.FileExists(t, p+".lockfile", "missing expected lock file") + return nil + }, + } + c, err := New(&ec, p) + require.NoError(t, err) + + err = c.Export(ctx, &fakeInternalCache{}, cache.ExportHints{}) + require.NoError(t, err) +} + +func TestLockError(t *testing.T) { + c, err := New(&fakeExternalCache{}, filepath.Join(t.TempDir(), t.Name())) + require.NoError(t, err) + expected := errors.New("expected") + c.l = fakeLock{lockErr: expected} + err = c.Export(ctx, &fakeInternalCache{}, cache.ExportHints{}) + require.EqualError(t, err, expected.Error()) +} + +func TestPreservesTimestampFileContent(t *testing.T) { + p := filepath.Join(t.TempDir(), t.Name()) + expected := []byte("expected") + err := os.WriteFile(p, expected, 0600) + require.NoError(t, err) + + ec := fakeExternalCache{} + c, err := New(&ec, p) + require.NoError(t, err) + + ic := fakeInternalCache{data: []byte("data")} + err = c.Export(ctx, &ic, cache.ExportHints{}) + require.NoError(t, err) + require.Equal(t, ic.data, ec.data) + + actual, err := os.ReadFile(p) + require.NoError(t, err) + require.Equal(t, expected, actual, "Cache truncated, or wrote to, the timestamp file") +} + +func TestRace(t *testing.T) { + ic := fakeInternalCache{} + ec := fakeExternalCache{} + c, err := New(&ec, filepath.Join(t.TempDir(), t.Name())) + require.NoError(t, err) + + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + if !t.Failed() { + err := c.Replace(ctx, &ic, cache.ReplaceHints{}) + if err == nil { + err = c.Export(ctx, &ic, cache.ExportHints{}) + } + if err != nil { + t.Errorf("%d: %s", i, err) + } + } + }(i) + } + wg.Wait() +} + +func TestReplace(t *testing.T) { + ic := fakeInternalCache{} + ec := fakeExternalCache{} + p := filepath.Join(t.TempDir(), t.Name()) + + c, err := New(&ec, p) + require.NoError(t, err) + require.Empty(t, ic) + + // Replace should read data from the accessor (external cache) into the in-memory cache, observing the timestamp file + f, err := os.Create(p) + require.NoError(t, err) + require.NoError(t, f.Close()) + for i := uint8(0); i < 4; i++ { + ec.data = []byte{i} + err = c.Replace(ctx, &ic, cache.ReplaceHints{}) + require.NoError(t, err) + require.EqualValues(t, ec.data, ic.data) + // touch the timestamp file to indicate another accessor wrote data. Backdating ensures the + // timestamp changes between iterations even when one executes faster than file time resolution + tm := time.Now().Add(-time.Duration(i+1) * time.Second) + require.NoError(t, os.Chtimes(p, tm, tm)) + } + + // Replace should return in-memory data when the timestamp indicates no intervening write to the persistent cache + for i := 0; i < 4; i++ { + err = c.Replace(ctx, &ic, cache.ReplaceHints{}) + require.NoError(t, err) + // ec.data hasn't changed; ic.data shouldn't change either + require.EqualValues(t, ec.data, ic.data) + } +} + +func TestReplaceErrors(t *testing.T) { + realDelay := retryDelay + retryDelay = 0 + t.Cleanup(func() { retryDelay = realDelay }) + expected := errors.New("expected") + + t.Run("read", func(t *testing.T) { + ec := &fakeExternalCache{readCallback: func() error { + return expected + }} + p := filepath.Join(t.TempDir(), t.Name()) + c, err := New(ec, p) + require.NoError(t, err) + + err = c.Replace(ctx, &fakeInternalCache{}, cache.ReplaceHints{}) + require.Equal(t, expected, err) + }) + + for _, transient := range []bool{true, false} { + name := "unmarshal error" + if transient { + name = "transient " + name + } + t.Run(name, func(t *testing.T) { + tries := 0 + ic := fakeInternalCache{unmarshalCallback: func() error { + tries++ + if transient && tries > 1 { + return nil + } + return expected + }} + ec := &fakeExternalCache{} + + p := filepath.Join(t.TempDir(), t.Name()) + c, err := New(ec, p) + require.NoError(t, err) + + cx, cancel := context.WithTimeout(ctx, time.Millisecond) + defer cancel() + err = c.Replace(cx, &ic, cache.ReplaceHints{}) + // err should be nil if the unmarshaling error was transient, non-nil if it wasn't + require.Equal(t, transient, err == nil) + }) + } +} + +func TestUnlockError(t *testing.T) { + p := filepath.Join(t.TempDir(), t.Name()) + a := fakeExternalCache{} + c, err := New(&a, p) + require.NoError(t, err) + + // Export should return an error from Unlock()... + unlockErr := errors.New("unlock error") + c.l = fakeLock{unlockErr: unlockErr} + err = c.Export(ctx, &fakeInternalCache{}, cache.ExportHints{}) + require.Equal(t, unlockErr, err) + + // ...unless another of its calls returned an error + writeErr := errors.New("write error") + a.writeCallback = func() error { return writeErr } + err = c.Export(ctx, &fakeInternalCache{}, cache.ExportHints{}) + require.Equal(t, writeErr, err) +} From a45cd8c5e9cffe6a68f942f44e0567fc8c472347 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Mon, 1 May 2023 14:45:07 -0700 Subject: [PATCH 5/7] integration tests and benchmarks --- extensions/integration_test.go | 161 +++++++++++++++++++++++++++++++++ extensions/mock_test.go | 144 +++++++++++++++++++++++++++++ extensions/perf_test.go | 121 +++++++++++++++++++++++++ 3 files changed, 426 insertions(+) create mode 100644 extensions/integration_test.go create mode 100644 extensions/mock_test.go create mode 100644 extensions/perf_test.go diff --git a/extensions/integration_test.go b/extensions/integration_test.go new file mode 100644 index 0000000..d2df5e1 --- /dev/null +++ b/extensions/integration_test.go @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +package extensions + +import ( + "context" + "fmt" + "net/http" + "path/filepath" + "strings" + "sync" + "testing" + + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/accessor/file" + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/stretchr/testify/require" +) + +var ctx = context.Background() + +func TestConfidentialClient(t *testing.T) { + t.Parallel() + p := filepath.Join(t.TempDir(), t.Name()) + a, err := file.New(p) + require.NoError(t, err) + c, err := cache.New(a, p+".timestamp") + require.NoError(t, err) + cred, err := confidential.NewCredFromSecret("*") + require.NoError(t, err) + client, err := confidential.New( + "https://login.microsoftonline.com/tenant", "clientID", cred, confidential.WithCache(c), confidential.WithHTTPClient(&mockSTS{}), + ) + require.NoError(t, err) + + gr := 20 + wg := sync.WaitGroup{} + for i := 0; i < gr; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + if t.Failed() { + return + } + s := fmt.Sprint(n) + ar, err := client.AcquireTokenByCredential(ctx, []string{s}) + switch { + case err != nil: + t.Error(err) + case ar.AccessToken != s: + t.Errorf("possible test bug: expected %q from STS, got %q", s, ar.AccessToken) + default: + ar, err = client.AcquireTokenSilent(ctx, []string{s}) + if err != nil { + t.Error(err) + } else if ar.AccessToken != s { + t.Errorf("possible cache corruption: expected %q, got %q", s, ar.AccessToken) + } + } + }(i) + } + wg.Wait() + if t.Failed() { + return + } + + // cache should have an access token from each goroutine + lost := gr + for i := 0; i < gr; i++ { + s := fmt.Sprint(i) + ar, err := client.AcquireTokenSilent(ctx, []string{s}) + if err == nil { + lost-- + if ar.AccessToken != s { + t.Errorf("possible cache corruption: expected %q, got %q", s, ar.AccessToken) + } + } + } + require.Equal(t, 0, lost, "lost %d/%d tokens", lost, gr) +} + +func TestPublicClient(t *testing.T) { + t.Parallel() + p := filepath.Join(t.TempDir(), t.Name()) + a, err := file.New(p) + require.NoError(t, err) + c, err := cache.New(a, p+".timestamp") + require.NoError(t, err) + sts := mockSTS{} + client, err := public.New("clientID", public.WithCache(c), public.WithHTTPClient(&sts)) + require.NoError(t, err) + + gr := 20 + wg := sync.WaitGroup{} + for i := 0; i < gr; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + if t.Failed() { + return + } + s := fmt.Sprint(n) + ar, err := client.AcquireTokenByUsernamePassword(ctx, []string{s}, s, "password") + switch { + case err != nil: + t.Error(err) + case ar.AccessToken != s: + t.Errorf("possible test bug: expected %q from STS, got %q", s, ar.AccessToken) + default: + ar, err = client.AcquireTokenSilent(ctx, []string{s}, public.WithSilentAccount(ar.Account)) + if err != nil { + t.Error(err) + } else if ar.AccessToken != s { + t.Errorf("possible cache corruption: expected %q, got %q", s, ar.AccessToken) + } + } + }(i) + } + wg.Wait() + if t.Failed() { + return + } + + accounts, err := client.Accounts(ctx) + require.NoError(t, err) + require.Equal(t, gr, len(accounts), "should have a cached account for each goroutine") + + // Verify no access token cached above was lost due to a race. Silent auth should return a cached + // access token given any scope above. A token request during this loop indicates the client + // exchanged a refresh token to reacquire the access token it should have found in the cache. + lostATs, reqs := 0, 0 + sts.tokenRequestCallback = func(*http.Request) { reqs++ } + for _, a := range accounts { + s, _, found := strings.Cut(a.HomeAccountID, ".") + require.True(t, found, "unexpected home account ID %q", a.HomeAccountID) + ar, err := client.AcquireTokenSilent(ctx, []string{s}, public.WithSilentAccount(a)) + if err != nil { + // the cache has no access token for the expected scope and no refresh token for the account + lostATs++ + } else if ar.AccessToken != s { + t.Errorf("possible cache corruption: expected %q, got %q", s, ar.AccessToken) + } + } + require.Equal(t, 0, lostATs+reqs, "lost %d/%d access tokens", reqs, gr) + + // The cache has all the expected access tokens but may have lost refresh tokens, so we try silent + // auth again for each account, passing a new scope to force the client to use a refresh token. + lostRTs := 0 + for _, a := range accounts { + s := "novelscope" + ar, err := client.AcquireTokenSilent(ctx, []string{s}, public.WithSilentAccount(a)) + if err != nil { + lostRTs++ + } else if ar.AccessToken != s { + t.Errorf("possible cache corruption: expected %q, got %q", s, ar.AccessToken) + } + } + require.Equal(t, 0, lostRTs, "lost %d/%d refresh tokens", lostRTs, gr) +} diff --git a/extensions/mock_test.go b/extensions/mock_test.go new file mode 100644 index 0000000..9385c3b --- /dev/null +++ b/extensions/mock_test.go @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +package extensions + +import ( + "bytes" + "encoding/base64" + "fmt" + "io" + "net/http" + "strings" +) + +// mockSTS returns mock Azure AD responses so tests don't have to account for MSAL metadata requests +type mockSTS struct { + tokenRequestCallback func(*http.Request) +} + +func (m *mockSTS) Do(req *http.Request) (*http.Response, error) { + res := http.Response{StatusCode: http.StatusOK} + switch s := strings.Split(req.URL.Path, "/"); s[len(s)-1] { + case "instance": + res.Body = io.NopCloser(bytes.NewReader(instanceMetadata("tenant"))) + case "openid-configuration": + res.Body = io.NopCloser(bytes.NewReader(tenantMetadata("tenant"))) + case "token": + if m.tokenRequestCallback != nil { + m.tokenRequestCallback(req) + } + if err := req.ParseForm(); err != nil { + return nil, err + } + scope := strings.Split(req.FormValue("scope"), " ")[0] + userinfo := "" + if upn := req.FormValue("username"); upn != "" { + clientinfo := base64.RawStdEncoding.EncodeToString([]byte(fmt.Sprintf(`{"uid":"%s","utid":"utid"}`, upn))) + userinfo = fmt.Sprintf(`, "client_info":"%s", "id_token":"x.e30", "refresh_token": "rt"`, clientinfo) + } + res.Body = io.NopCloser(bytes.NewReader([]byte(fmt.Sprintf(`{"access_token": %q, "expires_in": 3600%s}`, scope, userinfo)))) + default: + // User realm metadata request paths look like "/common/UserRealm/user@domain". + // Matching on the UserRealm segment avoids having to know the UPN. + if s[len(s)-2] == "UserRealm" { + res.Body = io.NopCloser( + strings.NewReader(`{"account_type":"Managed","cloud_audience_urn":"urn","cloud_instance_name":"...","domain_name":"..."}`), + ) + } else { + panic("unexpected request " + req.URL.String()) + } + } + return &res, nil +} + +func (m *mockSTS) CloseIdleConnections() {} + +func instanceMetadata(tenant string) []byte { + return []byte(strings.ReplaceAll(`{ + "tenant_discovery_endpoint": "https://login.microsoftonline.com/{tenant}/v2.0/.well-known/openid-configuration", + "api-version": "1.1", + "metadata": [ + { + "preferred_network": "login.microsoftonline.com", + "preferred_cache": "login.windows.net", + "aliases": [ + "login.microsoftonline.com", + "login.windows.net", + "login.microsoft.com", + "sts.windows.net" + ] + } + ] + }`, "{tenant}", tenant)) +} + +func tenantMetadata(tenant string) []byte { + return []byte(strings.ReplaceAll(`{ + "token_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token", + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "private_key_jwt", + "client_secret_basic" + ], + "jwks_uri": "https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys", + "response_modes_supported": [ + "query", + "fragment", + "form_post" + ], + "subject_types_supported": [ + "pairwise" + ], + "id_token_signing_alg_values_supported": [ + "RS256" + ], + "response_types_supported": [ + "code", + "id_token", + "code id_token", + "id_token token" + ], + "scopes_supported": [ + "openid", + "profile", + "email", + "offline_access" + ], + "issuer": "https://login.microsoftonline.com/{tenant}/v2.0", + "request_uri_parameter_supported": false, + "userinfo_endpoint": "https://graph.microsoft.com/oidc/userinfo", + "authorization_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/authorize", + "device_authorization_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/devicecode", + "http_logout_supported": true, + "frontchannel_logout_supported": true, + "end_session_endpoint": "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/logout", + "claims_supported": [ + "sub", + "iss", + "cloud_instance_name", + "cloud_instance_host_name", + "cloud_graph_host_name", + "msgraph_host", + "aud", + "exp", + "iat", + "auth_time", + "acr", + "nonce", + "preferred_username", + "name", + "tid", + "ver", + "at_hash", + "c_hash", + "email" + ], + "kerberos_endpoint": "https://login.microsoftonline.com/{tenant}/kerberos", + "tenant_region_scope": "NA", + "cloud_instance_name": "microsoftonline.com", + "cloud_graph_host_name": "graph.windows.net", + "msgraph_host": "graph.microsoft.com", + "rbac_url": "https://pas.windows.net" + }`, "{tenant}", tenant)) +} diff --git a/extensions/perf_test.go b/extensions/perf_test.go new file mode 100644 index 0000000..62bb660 --- /dev/null +++ b/extensions/perf_test.go @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +package extensions + +import ( + "fmt" + "path/filepath" + "sync" + "testing" + + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/accessor/file" + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/stretchr/testify/require" +) + +// this file benchmarks MSAL clients using Cache and file.Accessor + +func BenchmarkConfidentialClient(b *testing.B) { + p := filepath.Join(b.TempDir(), b.Name()) + a, err := file.New(p) + require.NoError(b, err) + c, err := cache.New(a, p+".timestamp") + require.NoError(b, err) + cred, err := confidential.NewCredFromSecret("*") + require.NoError(b, err) + client, err := confidential.New( + "https://login.microsoftonline.com/tenant", "clientID", cred, confidential.WithCache(c), confidential.WithHTTPClient(&mockSTS{}), + ) + require.NoError(b, err) + + gr := 10 + wg := sync.WaitGroup{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + for i := 0; i < gr; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + s := fmt.Sprint(n) + _, _ = client.AcquireTokenByCredential(ctx, []string{s}) + _, _ = client.AcquireTokenSilent(ctx, []string{s}) + }(i) + } + wg.Wait() + } +} + +func BenchmarkConfidentialClient_NoPersistence(b *testing.B) { + cred, err := confidential.NewCredFromSecret("*") + require.NoError(b, err) + client, err := confidential.New("https://login.microsoftonline.com/tenant", "clientID", cred, confidential.WithHTTPClient(&mockSTS{})) + require.NoError(b, err) + + gr := 10 + wg := sync.WaitGroup{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + for i := 0; i < gr; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + s := []string{fmt.Sprint(n)} + _, _ = client.AcquireTokenByCredential(ctx, s) + _, _ = client.AcquireTokenSilent(ctx, s) + }(i) + } + wg.Wait() + } +} + +func BenchmarkPublicClient(b *testing.B) { + p := filepath.Join(b.TempDir(), b.Name()) + a, err := file.New(p) + require.NoError(b, err) + c, err := cache.New(a, p+".timestamp") + require.NoError(b, err) + sts := mockSTS{} + client, err := public.New("clientID", public.WithCache(c), public.WithHTTPClient(&sts)) + require.NoError(b, err) + + gr := 10 + wg := sync.WaitGroup{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + for i := 0; i < gr; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + s := fmt.Sprint(n) + ar, _ := client.AcquireTokenByUsernamePassword(ctx, []string{s}, s, "password") + _, _ = client.AcquireTokenSilent(ctx, []string{s}, public.WithSilentAccount(ar.Account)) + }(i) + } + wg.Wait() + } +} + +func BenchmarkPublicClient_NoPersistence(b *testing.B) { + sts := mockSTS{} + client, err := public.New("clientID", public.WithHTTPClient(&sts)) + require.NoError(b, err) + + gr := 10 + wg := sync.WaitGroup{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + for i := 0; i < gr; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + s := fmt.Sprint(n) + ar, _ := client.AcquireTokenByUsernamePassword(ctx, []string{s}, s, "password") + _, _ = client.AcquireTokenSilent(ctx, []string{s}, public.WithSilentAccount(ar.Account)) + }(i) + } + wg.Wait() + } +} From 4f75c0c088c9be198090a5f419293e13905177c4 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Fri, 19 May 2023 09:48:13 -0700 Subject: [PATCH 6/7] extensions/accessor -> extensions/cache/accessor --- extensions/{ => cache}/accessor/accessor.go | 0 extensions/{ => cache}/accessor/file/file.go | 2 +- extensions/{ => cache}/accessor/file/file_test.go | 0 extensions/cache/cache.go | 2 +- extensions/integration_test.go | 2 +- extensions/perf_test.go | 2 +- 6 files changed, 4 insertions(+), 4 deletions(-) rename extensions/{ => cache}/accessor/accessor.go (100%) rename extensions/{ => cache}/accessor/file/file.go (97%) rename extensions/{ => cache}/accessor/file/file_test.go (100%) diff --git a/extensions/accessor/accessor.go b/extensions/cache/accessor/accessor.go similarity index 100% rename from extensions/accessor/accessor.go rename to extensions/cache/accessor/accessor.go diff --git a/extensions/accessor/file/file.go b/extensions/cache/accessor/file/file.go similarity index 97% rename from extensions/accessor/file/file.go rename to extensions/cache/accessor/file/file.go index a57c814..e9e3866 100644 --- a/extensions/accessor/file/file.go +++ b/extensions/cache/accessor/file/file.go @@ -10,7 +10,7 @@ import ( "path/filepath" "sync" - "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/accessor" + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache/accessor" ) // Storage stores data in an unencrypted file. diff --git a/extensions/accessor/file/file_test.go b/extensions/cache/accessor/file/file_test.go similarity index 100% rename from extensions/accessor/file/file_test.go rename to extensions/cache/accessor/file/file_test.go diff --git a/extensions/cache/cache.go b/extensions/cache/cache.go index 4b9a524..c6626d4 100644 --- a/extensions/cache/cache.go +++ b/extensions/cache/cache.go @@ -11,7 +11,7 @@ import ( "sync" "time" - "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/accessor" + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache/accessor" "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/internal/lock" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" ) diff --git a/extensions/integration_test.go b/extensions/integration_test.go index d2df5e1..a1109ad 100644 --- a/extensions/integration_test.go +++ b/extensions/integration_test.go @@ -12,8 +12,8 @@ import ( "sync" "testing" - "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/accessor/file" "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache" + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache/accessor/file" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" "github.com/stretchr/testify/require" diff --git a/extensions/perf_test.go b/extensions/perf_test.go index 62bb660..40ce4ac 100644 --- a/extensions/perf_test.go +++ b/extensions/perf_test.go @@ -9,8 +9,8 @@ import ( "sync" "testing" - "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/accessor/file" "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache" + "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache/accessor/file" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" "github.com/stretchr/testify/require" From 931666295399c8a5b6682a474e78baed748096ba Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Fri, 19 May 2023 10:25:34 -0700 Subject: [PATCH 7/7] refactor benchmarks --- extensions/perf_test.go | 148 +++++++++++++++++----------------------- 1 file changed, 63 insertions(+), 85 deletions(-) diff --git a/extensions/perf_test.go b/extensions/perf_test.go index 40ce4ac..bf4cc86 100644 --- a/extensions/perf_test.go +++ b/extensions/perf_test.go @@ -11,6 +11,7 @@ import ( "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache" "github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache/accessor/file" + mcache "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" "github.com/stretchr/testify/require" @@ -18,104 +19,81 @@ import ( // this file benchmarks MSAL clients using Cache and file.Accessor -func BenchmarkConfidentialClient(b *testing.B) { +func newCache(b *testing.B) *cache.Cache { p := filepath.Join(b.TempDir(), b.Name()) a, err := file.New(p) require.NoError(b, err) c, err := cache.New(a, p+".timestamp") require.NoError(b, err) - cred, err := confidential.NewCredFromSecret("*") - require.NoError(b, err) - client, err := confidential.New( - "https://login.microsoftonline.com/tenant", "clientID", cred, confidential.WithCache(c), confidential.WithHTTPClient(&mockSTS{}), - ) - require.NoError(b, err) - - gr := 10 - wg := sync.WaitGroup{} - b.ResetTimer() - for i := 0; i < b.N; i++ { - for i := 0; i < gr; i++ { - wg.Add(1) - go func(n int) { - defer wg.Done() - s := fmt.Sprint(n) - _, _ = client.AcquireTokenByCredential(ctx, []string{s}) - _, _ = client.AcquireTokenSilent(ctx, []string{s}) - }(i) - } - wg.Wait() - } + return c } -func BenchmarkConfidentialClient_NoPersistence(b *testing.B) { - cred, err := confidential.NewCredFromSecret("*") - require.NoError(b, err) - client, err := confidential.New("https://login.microsoftonline.com/tenant", "clientID", cred, confidential.WithHTTPClient(&mockSTS{})) - require.NoError(b, err) - - gr := 10 - wg := sync.WaitGroup{} - b.ResetTimer() - for i := 0; i < b.N; i++ { - for i := 0; i < gr; i++ { - wg.Add(1) - go func(n int) { - defer wg.Done() - s := []string{fmt.Sprint(n)} - _, _ = client.AcquireTokenByCredential(ctx, s) - _, _ = client.AcquireTokenSilent(ctx, s) - }(i) +func BenchmarkConfidentialClient(b *testing.B) { + for _, baseline := range []bool{false, true} { + name := "file accessor" + if baseline { + name = "no persistence" } - wg.Wait() + b.Run(name, func(b *testing.B) { + var c mcache.ExportReplace + if !baseline { + c = newCache(b) + } + cred, err := confidential.NewCredFromSecret("*") + require.NoError(b, err) + client, err := confidential.New( + "https://login.microsoftonline.com/tenant", "ID", cred, confidential.WithCache(c), confidential.WithHTTPClient(&mockSTS{}), + ) + require.NoError(b, err) + + gr := 10 + wg := sync.WaitGroup{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + for i := 0; i < gr; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + s := fmt.Sprint(n) + _, _ = client.AcquireTokenByCredential(ctx, []string{s}) + _, _ = client.AcquireTokenSilent(ctx, []string{s}) + }(i) + } + wg.Wait() + } + }) } } func BenchmarkPublicClient(b *testing.B) { - p := filepath.Join(b.TempDir(), b.Name()) - a, err := file.New(p) - require.NoError(b, err) - c, err := cache.New(a, p+".timestamp") - require.NoError(b, err) - sts := mockSTS{} - client, err := public.New("clientID", public.WithCache(c), public.WithHTTPClient(&sts)) - require.NoError(b, err) - - gr := 10 - wg := sync.WaitGroup{} - b.ResetTimer() - for i := 0; i < b.N; i++ { - for i := 0; i < gr; i++ { - wg.Add(1) - go func(n int) { - defer wg.Done() - s := fmt.Sprint(n) - ar, _ := client.AcquireTokenByUsernamePassword(ctx, []string{s}, s, "password") - _, _ = client.AcquireTokenSilent(ctx, []string{s}, public.WithSilentAccount(ar.Account)) - }(i) + for _, baseline := range []bool{false, true} { + name := "file accessor" + if baseline { + name = "no persistence" } - wg.Wait() - } -} + b.Run(name, func(b *testing.B) { + var c mcache.ExportReplace + if !baseline { + c = newCache(b) + } + client, err := public.New("clientID", public.WithCache(c), public.WithHTTPClient(&mockSTS{})) + require.NoError(b, err) -func BenchmarkPublicClient_NoPersistence(b *testing.B) { - sts := mockSTS{} - client, err := public.New("clientID", public.WithHTTPClient(&sts)) - require.NoError(b, err) - - gr := 10 - wg := sync.WaitGroup{} - b.ResetTimer() - for i := 0; i < b.N; i++ { - for i := 0; i < gr; i++ { - wg.Add(1) - go func(n int) { - defer wg.Done() - s := fmt.Sprint(n) - ar, _ := client.AcquireTokenByUsernamePassword(ctx, []string{s}, s, "password") - _, _ = client.AcquireTokenSilent(ctx, []string{s}, public.WithSilentAccount(ar.Account)) - }(i) - } - wg.Wait() + gr := 10 + wg := sync.WaitGroup{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + for i := 0; i < gr; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + s := fmt.Sprint(n) + ar, _ := client.AcquireTokenByUsernamePassword(ctx, []string{s}, s, "password") + _, _ = client.AcquireTokenSilent(ctx, []string{s}, public.WithSilentAccount(ar.Account)) + }(i) + } + wg.Wait() + } + }) } }