Skip to content

Commit adcf691

Browse files
committed
Windows storage accessor
1 parent 7f96225 commit adcf691

File tree

5 files changed

+228
-3
lines changed

5 files changed

+228
-3
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License. See LICENSE in the project root for license information.
3+
4+
// TODO: add other platforms
5+
//go:build windows
6+
// +build windows
7+
8+
package accessor
9+
10+
import (
11+
"context"
12+
"os"
13+
"path/filepath"
14+
"runtime"
15+
"sync"
16+
"testing"
17+
18+
"github.com/stretchr/testify/require"
19+
)
20+
21+
const msalextManualTest = "MSALEXT_MANUAL_TEST"
22+
23+
var (
24+
ctx = context.Background()
25+
26+
// the Windows implementation doesn't require user interaction
27+
runTests = runtime.GOOS == "windows" || os.Getenv(msalextManualTest) != ""
28+
)
29+
30+
func TestRace(t *testing.T) {
31+
if !runTests {
32+
t.Skipf("set %s to run this test", msalextManualTest)
33+
}
34+
p := filepath.Join(t.TempDir(), t.Name())
35+
a, err := New(p)
36+
require.NoError(t, err)
37+
38+
actual, err := a.Read(ctx)
39+
require.NoError(t, err)
40+
require.Empty(t, actual)
41+
42+
expected := "expected"
43+
wg := sync.WaitGroup{}
44+
for i := 0; i < 20; i++ {
45+
wg.Add(1)
46+
go func() {
47+
defer wg.Done()
48+
if !t.Failed() {
49+
actual := []byte{}
50+
err := a.Write(ctx, []byte(expected))
51+
if err == nil {
52+
actual, err = a.Read(ctx)
53+
}
54+
if err != nil {
55+
t.Error(err)
56+
} else if a := string(actual); a != expected {
57+
t.Errorf("expected %q, got %q", expected, a)
58+
}
59+
}
60+
}()
61+
}
62+
wg.Wait()
63+
}
64+
65+
func TestRoundTrip(t *testing.T) {
66+
if !runTests {
67+
t.Skipf("set %s to run this test", msalextManualTest)
68+
}
69+
p := filepath.Join(t.TempDir(), t.Name())
70+
a, err := New(p)
71+
require.NoError(t, err)
72+
73+
actual, err := a.Read(ctx)
74+
require.NoError(t, err)
75+
require.Empty(t, actual)
76+
77+
expected := []byte("expected")
78+
err = a.Write(ctx, expected)
79+
require.NoError(t, err)
80+
81+
actual, err = a.Read(ctx)
82+
require.NoError(t, err)
83+
require.Equal(t, expected, actual)
84+
}

extensions/accessor/windows.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License. See LICENSE in the project root for license information.
3+
4+
//go:build windows
5+
// +build windows
6+
7+
package accessor
8+
9+
import (
10+
"context"
11+
"errors"
12+
"math"
13+
"os"
14+
"path/filepath"
15+
"sync"
16+
"unsafe"
17+
18+
"golang.org/x/sys/windows"
19+
)
20+
21+
// Storage stores data in a file encrypted by the Windows data protection API.
22+
type Storage struct {
23+
m *sync.RWMutex
24+
p string
25+
}
26+
27+
// New is the constructor for Storage. "p" is the path to the file in which to store data.
28+
func New(p string) (*Storage, error) {
29+
return &Storage{m: &sync.RWMutex{}, p: p}, nil
30+
}
31+
32+
// Read returns data from the file. If the file doesn't exist, Read returns a nil slice and error.
33+
func (s *Storage) Read(context.Context) ([]byte, error) {
34+
s.m.RLock()
35+
defer s.m.RUnlock()
36+
37+
data, err := os.ReadFile(s.p)
38+
if errors.Is(err, os.ErrNotExist) {
39+
return nil, nil
40+
}
41+
if err != nil {
42+
return nil, err
43+
}
44+
if len(data) > 0 {
45+
data, err = dpapi(decrypt, data)
46+
}
47+
return data, err
48+
}
49+
50+
// Write stores data in the file, creating the file if it doesn't exist.
51+
func (s *Storage) Write(ctx context.Context, data []byte) error {
52+
s.m.Lock()
53+
defer s.m.Unlock()
54+
55+
data, err := dpapi(encrypt, data)
56+
if err != nil {
57+
return err
58+
}
59+
err = os.WriteFile(s.p, data, 0600)
60+
if errors.Is(err, os.ErrNotExist) {
61+
dir := filepath.Dir(s.p)
62+
if err = os.MkdirAll(dir, 0700); err == nil {
63+
err = os.WriteFile(s.p, data, 0600)
64+
}
65+
}
66+
return err
67+
}
68+
69+
type operation int
70+
71+
const (
72+
decrypt operation = iota
73+
encrypt
74+
)
75+
76+
func dpapi(op operation, data []byte) (result []byte, err error) {
77+
out := windows.DataBlob{}
78+
defer func() {
79+
if out.Data != nil {
80+
_, e := windows.LocalFree(windows.Handle(unsafe.Pointer(out.Data)))
81+
// prefer returning DPAPI errors because they're more interesting than LocalFree errors
82+
if e != nil && err == nil {
83+
err = e
84+
}
85+
}
86+
}()
87+
in := windows.DataBlob{Data: &data[0], Size: uint32(len(data))}
88+
switch op {
89+
case decrypt:
90+
// https://learn.microsoft.com/windows/win32/api/dpapi/nf-dpapi-cryptunprotectdata
91+
err = windows.CryptUnprotectData(&in, nil, nil, 0, nil, 1, &out)
92+
case encrypt:
93+
// https://learn.microsoft.com/windows/win32/api/dpapi/nf-dpapi-cryptprotectdata
94+
err = windows.CryptProtectData(&in, nil, nil, 0, nil, 1, &out)
95+
default:
96+
err = errors.New("invalid operation")
97+
}
98+
if err == nil {
99+
// cast out.Data to a pointer to an arbitrarily long array, then slice the array and copy out.Size bytes from the
100+
// slice to result. This avoids allocating memory for a throwaway buffer but imposes a max size on the data because
101+
// the fictive array backing the slice can't be larger than the address space or the maximum value of an int. Those
102+
// values vary by platform, so the array size here is a compromise for 32-bit systems and allows ~2 GB of data.
103+
result = make([]byte, out.Size)
104+
source := (*[math.MaxInt32 - 1]byte)(unsafe.Pointer(out.Data))[:]
105+
copy(result, source)
106+
}
107+
return result, err
108+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License. See LICENSE in the project root for license information.
3+
4+
//go:build windows
5+
// +build windows
6+
7+
package accessor
8+
9+
import (
10+
"encoding/json"
11+
"os"
12+
"path/filepath"
13+
"testing"
14+
15+
"github.com/stretchr/testify/require"
16+
)
17+
18+
func TestWriteEncryption(t *testing.T) {
19+
p := filepath.Join(t.TempDir(), t.Name())
20+
a, err := New(p)
21+
require.NoError(t, err)
22+
23+
data := []byte(`{"key":"value"}`)
24+
require.NoError(t, json.Unmarshal(data, &struct{}{}), "test bug: data should unmarshal")
25+
require.NoError(t, a.Write(ctx, data))
26+
27+
// Write should have encrypted data before writing it to the file
28+
actual, err := os.ReadFile(p)
29+
require.NoError(t, err)
30+
require.NotEmpty(t, actual)
31+
err = json.Unmarshal(actual, &struct{}{})
32+
require.Error(t, err, "Unmarshal should fail because the file's content, being encrypted, isn't JSON")
33+
}

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ go 1.18
55
require (
66
github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0
77
github.com/stretchr/testify v1.8.2
8+
golang.org/x/sys v0.8.0
89
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c
910
)
1011

@@ -17,6 +18,5 @@ require (
1718
github.com/kylelemons/godebug v1.1.0 // indirect
1819
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect
1920
github.com/pmezard/go-difflib v1.0.0 // indirect
20-
golang.org/x/sys v0.5.0 // indirect
2121
gopkg.in/yaml.v3 v3.0.1 // indirect
2222
)

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
2626
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
2727
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
2828
golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
29-
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
30-
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
29+
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
30+
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
3131
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
3232
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
3333
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

0 commit comments

Comments
 (0)