Skip to content

Commit 5ae5202

Browse files
authored
Windows storage accessor (#15)
1 parent 7f96225 commit 5ae5202

File tree

5 files changed

+195
-3
lines changed

5 files changed

+195
-3
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License. See LICENSE in the project root for license information.
3+
4+
// TODO: add other platforms
5+
//go:build windows
6+
// +build windows
7+
8+
package accessor
9+
10+
import (
11+
"context"
12+
"os"
13+
"path/filepath"
14+
"runtime"
15+
"testing"
16+
17+
"github.com/stretchr/testify/require"
18+
)
19+
20+
const msalextManualTest = "MSALEXT_MANUAL_TEST"
21+
22+
var (
23+
ctx = context.Background()
24+
25+
// the Windows implementation doesn't require user interaction
26+
manualTests = runtime.GOOS == "windows" || os.Getenv(msalextManualTest) != ""
27+
)
28+
29+
func TestReadWrite(t *testing.T) {
30+
if !manualTests {
31+
t.Skipf("set %s to run this test", msalextManualTest)
32+
}
33+
for _, test := range []struct {
34+
desc string
35+
want []byte
36+
}{
37+
{desc: "Test when no stored data exists"},
38+
{desc: "Test writing data then reading it", want: []byte("want")},
39+
} {
40+
t.Run(test.desc, func(t *testing.T) {
41+
p := filepath.Join(t.TempDir(), t.Name())
42+
a, err := New(p)
43+
require.NoError(t, err)
44+
45+
if test.want != nil {
46+
cp := make([]byte, len(test.want))
47+
copy(cp, test.want)
48+
err = a.Write(ctx, cp)
49+
require.NoError(t, err)
50+
}
51+
52+
actual, err := a.Read(ctx)
53+
require.NoError(t, err)
54+
require.Equal(t, test.want, actual)
55+
})
56+
}
57+
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License. See LICENSE in the project root for license information.
3+
4+
//go:build windows
5+
// +build windows
6+
7+
package accessor
8+
9+
import (
10+
"context"
11+
"errors"
12+
"os"
13+
"path/filepath"
14+
"sync"
15+
"unsafe"
16+
17+
"golang.org/x/sys/windows"
18+
)
19+
20+
// Storage stores data in a file encrypted by the Windows data protection API.
21+
type Storage struct {
22+
m *sync.RWMutex
23+
p string
24+
}
25+
26+
// New is the constructor for Storage. "p" is the path to the file in which to store data.
27+
func New(p string) (*Storage, error) {
28+
return &Storage{m: &sync.RWMutex{}, p: p}, nil
29+
}
30+
31+
// Read returns data from the file. If the file doesn't exist, Read returns a nil slice and error.
32+
func (s *Storage) Read(context.Context) ([]byte, error) {
33+
s.m.RLock()
34+
defer s.m.RUnlock()
35+
36+
data, err := os.ReadFile(s.p)
37+
if errors.Is(err, os.ErrNotExist) {
38+
return nil, nil
39+
}
40+
if err != nil {
41+
return nil, err
42+
}
43+
if len(data) > 0 {
44+
data, err = dpapi(decrypt, data)
45+
}
46+
return data, err
47+
}
48+
49+
// Write stores data in the file, creating the file if it doesn't exist.
50+
func (s *Storage) Write(ctx context.Context, data []byte) error {
51+
s.m.Lock()
52+
defer s.m.Unlock()
53+
54+
data, err := dpapi(encrypt, data)
55+
if err != nil {
56+
return err
57+
}
58+
err = os.WriteFile(s.p, data, 0600)
59+
if errors.Is(err, os.ErrNotExist) {
60+
dir := filepath.Dir(s.p)
61+
if err = os.MkdirAll(dir, 0700); err == nil {
62+
err = os.WriteFile(s.p, data, 0600)
63+
}
64+
}
65+
return err
66+
}
67+
68+
type operation int
69+
70+
const (
71+
decrypt operation = iota
72+
encrypt
73+
)
74+
75+
func dpapi(op operation, data []byte) (result []byte, err error) {
76+
out := windows.DataBlob{}
77+
defer func() {
78+
if out.Data != nil {
79+
_, e := windows.LocalFree(windows.Handle(unsafe.Pointer(out.Data)))
80+
// prefer returning DPAPI errors because they're more interesting than LocalFree errors
81+
if e != nil && err == nil {
82+
err = e
83+
}
84+
}
85+
}()
86+
in := windows.DataBlob{Data: &data[0], Size: uint32(len(data))}
87+
switch op {
88+
case decrypt:
89+
// https://learn.microsoft.com/windows/win32/api/dpapi/nf-dpapi-cryptunprotectdata
90+
err = windows.CryptUnprotectData(&in, nil, nil, 0, nil, windows.CRYPTPROTECT_UI_FORBIDDEN, &out)
91+
case encrypt:
92+
// https://learn.microsoft.com/windows/win32/api/dpapi/nf-dpapi-cryptprotectdata
93+
err = windows.CryptProtectData(&in, nil, nil, 0, nil, windows.CRYPTPROTECT_UI_FORBIDDEN, &out)
94+
default:
95+
err = errors.New("invalid operation")
96+
}
97+
if err == nil {
98+
result = make([]byte, out.Size)
99+
copy(result, unsafe.Slice(out.Data, out.Size))
100+
}
101+
return result, err
102+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License. See LICENSE in the project root for license information.
3+
4+
//go:build windows
5+
// +build windows
6+
7+
package accessor
8+
9+
import (
10+
"encoding/json"
11+
"os"
12+
"path/filepath"
13+
"testing"
14+
15+
"github.com/stretchr/testify/require"
16+
)
17+
18+
func TestWriteEncryption(t *testing.T) {
19+
p := filepath.Join(t.TempDir(), t.Name())
20+
a, err := New(p)
21+
require.NoError(t, err)
22+
23+
data := []byte(`{"key":"value"}`)
24+
require.NoError(t, json.Unmarshal(data, &struct{}{}), "test bug: data should unmarshal")
25+
require.NoError(t, a.Write(ctx, data))
26+
27+
// Write should have encrypted data before writing it to the file
28+
actual, err := os.ReadFile(p)
29+
require.NoError(t, err)
30+
require.NotEmpty(t, actual)
31+
err = json.Unmarshal(actual, &struct{}{})
32+
require.Error(t, err, "Unmarshal should fail because the file's content, being encrypted, isn't JSON")
33+
}

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ go 1.18
55
require (
66
github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0
77
github.com/stretchr/testify v1.8.2
8+
golang.org/x/sys v0.8.0
89
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c
910
)
1011

@@ -17,6 +18,5 @@ require (
1718
github.com/kylelemons/godebug v1.1.0 // indirect
1819
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect
1920
github.com/pmezard/go-difflib v1.0.0 // indirect
20-
golang.org/x/sys v0.5.0 // indirect
2121
gopkg.in/yaml.v3 v3.0.1 // indirect
2222
)

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
2626
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
2727
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
2828
golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
29-
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
30-
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
29+
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
30+
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
3131
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
3232
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
3333
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

0 commit comments

Comments
 (0)