Skip to content

Commit 7f96225

Browse files
authored
Cache and accessor/file packages (#14)
1 parent 2cab7c3 commit 7f96225

File tree

10 files changed

+1002
-0
lines changed

10 files changed

+1002
-0
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License. See LICENSE in the project root for license information.
3+
4+
package accessor
5+
6+
import "context"
7+
8+
// Accessor accesses data storage.
9+
type Accessor interface {
10+
Read(context.Context) ([]byte, error)
11+
Write(context.Context, []byte) error
12+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License. See LICENSE in the project root for license information.
3+
4+
package file
5+
6+
import (
7+
"context"
8+
"errors"
9+
"os"
10+
"path/filepath"
11+
"sync"
12+
13+
"github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache/accessor"
14+
)
15+
16+
// Storage stores data in an unencrypted file.
17+
type Storage struct {
18+
m *sync.RWMutex
19+
p string
20+
}
21+
22+
// New is the constructor for Storage. "p" is the path to the file in which to store data.
23+
func New(p string) (*Storage, error) {
24+
return &Storage{m: &sync.RWMutex{}, p: p}, nil
25+
}
26+
27+
// Read returns the file's content or, if the file doesn't exist, a nil slice and error.
28+
func (s *Storage) Read(context.Context) ([]byte, error) {
29+
s.m.RLock()
30+
defer s.m.RUnlock()
31+
b, err := os.ReadFile(s.p)
32+
if errors.Is(err, os.ErrNotExist) {
33+
return nil, nil
34+
}
35+
return b, err
36+
}
37+
38+
// Write stores data in the file, overwriting any content, and creates the file if necessary.
39+
func (s *Storage) Write(ctx context.Context, data []byte) error {
40+
s.m.Lock()
41+
defer s.m.Unlock()
42+
err := os.WriteFile(s.p, data, 0600)
43+
if errors.Is(err, os.ErrNotExist) {
44+
dir := filepath.Dir(s.p)
45+
if err = os.MkdirAll(dir, 0700); err == nil {
46+
err = os.WriteFile(s.p, data, 0600)
47+
}
48+
}
49+
return err
50+
}
51+
52+
var _ accessor.Accessor = (*Storage)(nil)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License. See LICENSE in the project root for license information.
3+
4+
package file
5+
6+
import (
7+
"context"
8+
"os"
9+
"path/filepath"
10+
"testing"
11+
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
var ctx = context.Background()
16+
17+
func TestRead(t *testing.T) {
18+
p := filepath.Join(t.TempDir(), t.Name())
19+
a, err := New(p)
20+
require.NoError(t, err)
21+
22+
expected := []byte("expected")
23+
require.NoError(t, os.WriteFile(p, expected, 0600))
24+
25+
actual, err := a.Read(ctx)
26+
require.NoError(t, err)
27+
require.Equal(t, expected, actual)
28+
}
29+
30+
func TestRoundTrip(t *testing.T) {
31+
p := filepath.Join(t.TempDir(), "nonexistent", t.Name())
32+
a, err := New(p)
33+
require.NoError(t, err)
34+
35+
var expected []byte
36+
for i := 0; i < 4; i++ {
37+
actual, err := a.Read(ctx)
38+
require.NoError(t, err)
39+
require.Equal(t, expected, actual)
40+
41+
expected = append(expected, byte(i))
42+
require.NoError(t, a.Write(ctx, expected))
43+
}
44+
}
45+
46+
func TestWrite(t *testing.T) {
47+
p := filepath.Join(t.TempDir(), t.Name())
48+
for _, create := range []bool{true, false} {
49+
name := "file exists"
50+
if create {
51+
name = "new file"
52+
}
53+
t.Run(name, func(t *testing.T) {
54+
if create {
55+
f, err := os.OpenFile(p, os.O_CREATE|os.O_EXCL, 0600)
56+
require.NoError(t, err)
57+
require.NoError(t, f.Close())
58+
}
59+
a, err := New(p)
60+
require.NoError(t, err)
61+
62+
expected := []byte("expected")
63+
require.NoError(t, a.Write(ctx, expected))
64+
65+
actual, err := os.ReadFile(p)
66+
require.NoError(t, err)
67+
require.Equal(t, expected, actual)
68+
})
69+
}
70+
}

extensions/cache/cache.go

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License. See LICENSE in the project root for license information.
3+
4+
package cache
5+
6+
import (
7+
"context"
8+
"errors"
9+
"os"
10+
"path/filepath"
11+
"sync"
12+
"time"
13+
14+
"github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/cache/accessor"
15+
"github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/internal/lock"
16+
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
17+
)
18+
19+
var (
20+
// retryDelay lets tests prevent delays when faking errors in Replace
21+
retryDelay = 10 * time.Millisecond
22+
// timeout lets tests set the default amount of time allowed to read from the accessor
23+
timeout = time.Second
24+
)
25+
26+
// locker helps tests fake Lock
27+
type locker interface {
28+
Lock(context.Context) error
29+
Unlock() error
30+
}
31+
32+
// Cache caches authentication data in external storage, using a file lock to coordinate
33+
// access to it with other processes.
34+
type Cache struct {
35+
// a provides read/write access to storage
36+
a accessor.Accessor
37+
// data is accessor's data as of the last sync
38+
data []byte
39+
// l coordinates with other processes
40+
l locker
41+
// m coordinates this process's goroutines
42+
m *sync.Mutex
43+
// sync is when this Cache last read from or wrote to a
44+
sync time.Time
45+
// ts is the path to a file used to timestamp Export and Replace operations
46+
ts string
47+
}
48+
49+
// New is the constructor for Cache. "p" is the path to a file used to track when stored
50+
// data changes. Export will create this file and any directories in its path which don't
51+
// already exist.
52+
func New(a accessor.Accessor, p string) (*Cache, error) {
53+
lock, err := lock.New(p+".lockfile", retryDelay)
54+
if err != nil {
55+
return nil, err
56+
}
57+
return &Cache{a: a, l: lock, m: &sync.Mutex{}, ts: p}, err
58+
}
59+
60+
// Export writes the bytes marshaled by "m" to the accessor.
61+
// MSAL clients call this method automatically.
62+
func (c *Cache) Export(ctx context.Context, m cache.Marshaler, h cache.ExportHints) (err error) {
63+
c.m.Lock()
64+
defer c.m.Unlock()
65+
66+
data, err := m.Marshal()
67+
if err != nil {
68+
return err
69+
}
70+
err = c.l.Lock(ctx)
71+
if err != nil {
72+
return err
73+
}
74+
defer func() {
75+
e := c.l.Unlock()
76+
if err == nil {
77+
err = e
78+
}
79+
}()
80+
if err = c.a.Write(ctx, data); err == nil {
81+
// touch the timestamp file to record the time of this write; discard any
82+
// error because this is just an optimization to avoid redundant reads
83+
c.sync = time.Now()
84+
if er := os.Chtimes(c.ts, c.sync, c.sync); errors.Is(er, os.ErrNotExist) {
85+
if er = os.MkdirAll(filepath.Dir(c.ts), 0700); er == nil {
86+
f, _ := os.OpenFile(c.ts, os.O_CREATE, 0600)
87+
_ = f.Close()
88+
}
89+
}
90+
c.data = data
91+
}
92+
return err
93+
}
94+
95+
// Replace reads bytes from the accessor and unmarshals them to "u".
96+
// MSAL clients call this method automatically.
97+
func (c *Cache) Replace(ctx context.Context, u cache.Unmarshaler, h cache.ReplaceHints) error {
98+
c.m.Lock()
99+
defer c.m.Unlock()
100+
101+
// If the timestamp file indicates cached data hasn't changed since we last read or wrote it,
102+
// return c.data, which is the data as of that time. Discard any error from reading the timestamp
103+
// because this is just an optimization to prevent unnecessary reads. If we don't know whether
104+
// cached data has changed, we assume it has.
105+
read := true
106+
data := c.data
107+
f, err := os.Stat(c.ts)
108+
if err == nil {
109+
mt := f.ModTime()
110+
read = !mt.Equal(c.sync)
111+
}
112+
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
113+
var cancel context.CancelFunc
114+
ctx, cancel = context.WithTimeout(ctx, timeout)
115+
defer cancel()
116+
}
117+
// Unmarshal the accessor's data, reading it first if needed. We don't acquire the file lock before
118+
// reading from the accessor because it isn't strictly necessary and is relatively expensive. In the
119+
// unlikely event that a read overlaps with a write and returns malformed data, Unmarshal will return
120+
// an error and we'll try another read.
121+
for {
122+
if read {
123+
data, err = c.a.Read(ctx)
124+
if err != nil {
125+
break
126+
}
127+
}
128+
err = u.Unmarshal(data)
129+
if err == nil {
130+
break
131+
} else if !read {
132+
// c.data is apparently corrupt; Read from the accessor before trying again
133+
read = true
134+
}
135+
select {
136+
case <-ctx.Done():
137+
return ctx.Err()
138+
case <-time.After(retryDelay):
139+
// Unmarshal error; try again
140+
}
141+
}
142+
// Update the sync time only if we read from the accessor and unmarshaled its data. Otherwise
143+
// the data hasn't changed since the last read/write, or reading failed and we'll try again on
144+
// the next call.
145+
if err == nil && read {
146+
c.data = data
147+
if f, err := os.Stat(c.ts); err == nil {
148+
c.sync = f.ModTime()
149+
}
150+
}
151+
return err
152+
}
153+
154+
var _ cache.ExportReplace = (*Cache)(nil)

0 commit comments

Comments
 (0)