Skip to content

Commit 4a24cf1

Browse files
committed
file storage
1 parent 997db81 commit 4a24cf1

File tree

2 files changed

+122
-0
lines changed

2 files changed

+122
-0
lines changed

extensions/accessor/file/file.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License. See LICENSE in the project root for license information.
3+
4+
package file
5+
6+
import (
7+
"context"
8+
"errors"
9+
"os"
10+
"path/filepath"
11+
"sync"
12+
13+
"github.com/AzureAD/microsoft-authentication-extensions-for-go/extensions/accessor"
14+
)
15+
16+
// Storage stores data in an unencrypted file.
17+
type Storage struct {
18+
m *sync.RWMutex
19+
p string
20+
}
21+
22+
// New is the constructor for Storage. "p" is the path to the file in which to store data.
23+
func New(p string) (*Storage, error) {
24+
return &Storage{m: &sync.RWMutex{}, p: p}, nil
25+
}
26+
27+
// Read returns the file's content or, if the file doesn't exist, a nil slice and error.
28+
func (s *Storage) Read(context.Context) ([]byte, error) {
29+
s.m.RLock()
30+
defer s.m.RUnlock()
31+
b, err := os.ReadFile(s.p)
32+
if errors.Is(err, os.ErrNotExist) {
33+
return nil, nil
34+
}
35+
return b, err
36+
}
37+
38+
// Write stores data in the file, overwriting any content, and creates the file if necessary.
39+
func (s *Storage) Write(ctx context.Context, data []byte) error {
40+
s.m.Lock()
41+
defer s.m.Unlock()
42+
err := os.WriteFile(s.p, data, 0600)
43+
if errors.Is(err, os.ErrNotExist) {
44+
dir := filepath.Dir(s.p)
45+
if err = os.MkdirAll(dir, 0700); err == nil {
46+
err = os.WriteFile(s.p, data, 0600)
47+
}
48+
}
49+
return err
50+
}
51+
52+
var _ accessor.Accessor = (*Storage)(nil)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License. See LICENSE in the project root for license information.
3+
4+
package file
5+
6+
import (
7+
"context"
8+
"os"
9+
"path/filepath"
10+
"testing"
11+
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
var ctx = context.Background()
16+
17+
func TestRead(t *testing.T) {
18+
p := filepath.Join(t.TempDir(), t.Name())
19+
a, err := New(p)
20+
require.NoError(t, err)
21+
22+
expected := []byte("expected")
23+
require.NoError(t, os.WriteFile(p, expected, 0600))
24+
25+
actual, err := a.Read(ctx)
26+
require.NoError(t, err)
27+
require.Equal(t, expected, actual)
28+
}
29+
30+
func TestRoundTrip(t *testing.T) {
31+
p := filepath.Join(t.TempDir(), "nonexistent", t.Name())
32+
a, err := New(p)
33+
require.NoError(t, err)
34+
35+
var expected []byte
36+
for i := 0; i < 4; i++ {
37+
actual, err := a.Read(ctx)
38+
require.NoError(t, err)
39+
require.Equal(t, expected, actual)
40+
41+
expected = append(expected, byte(i))
42+
require.NoError(t, a.Write(ctx, expected))
43+
}
44+
}
45+
46+
func TestWrite(t *testing.T) {
47+
p := filepath.Join(t.TempDir(), t.Name())
48+
for _, create := range []bool{true, false} {
49+
name := "file exists"
50+
if create {
51+
name = "new file"
52+
}
53+
t.Run(name, func(t *testing.T) {
54+
if create {
55+
f, err := os.OpenFile(p, os.O_CREATE|os.O_EXCL, 0600)
56+
require.NoError(t, err)
57+
require.NoError(t, f.Close())
58+
}
59+
a, err := New(p)
60+
require.NoError(t, err)
61+
62+
expected := []byte("expected")
63+
require.NoError(t, a.Write(ctx, expected))
64+
65+
actual, err := os.ReadFile(p)
66+
require.NoError(t, err)
67+
require.Equal(t, expected, actual)
68+
})
69+
}
70+
}

0 commit comments

Comments
 (0)