diff --git a/extensions/cache/accessor/storage_test.go b/extensions/cache/accessor/storage_test.go new file mode 100755 index 0000000..88c66aa --- /dev/null +++ b/extensions/cache/accessor/storage_test.go @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +// TODO: add other platforms +//go:build windows +// +build windows + +package accessor + +import ( + "context" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/stretchr/testify/require" +) + +const msalextManualTest = "MSALEXT_MANUAL_TEST" + +var ( + ctx = context.Background() + + // the Windows implementation doesn't require user interaction + manualTests = runtime.GOOS == "windows" || os.Getenv(msalextManualTest) != "" +) + +func TestReadWrite(t *testing.T) { + if !manualTests { + t.Skipf("set %s to run this test", msalextManualTest) + } + for _, test := range []struct { + desc string + want []byte + }{ + {desc: "Test when no stored data exists"}, + {desc: "Test writing data then reading it", want: []byte("want")}, + } { + t.Run(test.desc, func(t *testing.T) { + p := filepath.Join(t.TempDir(), t.Name()) + a, err := New(p) + require.NoError(t, err) + + if test.want != nil { + cp := make([]byte, len(test.want)) + copy(cp, test.want) + err = a.Write(ctx, cp) + require.NoError(t, err) + } + + actual, err := a.Read(ctx) + require.NoError(t, err) + require.Equal(t, test.want, actual) + }) + } +} diff --git a/extensions/cache/accessor/windows.go b/extensions/cache/accessor/windows.go new file mode 100644 index 0000000..954ce70 --- /dev/null +++ b/extensions/cache/accessor/windows.go @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +//go:build windows +// +build windows + +package accessor + +import ( + "context" + "errors" + "os" + "path/filepath" + "sync" + "unsafe" + + "golang.org/x/sys/windows" +) + +// Storage stores data in a file encrypted by the Windows data protection API. +type Storage struct { + m *sync.RWMutex + p string +} + +// New is the constructor for Storage. "p" is the path to the file in which to store data. +func New(p string) (*Storage, error) { + return &Storage{m: &sync.RWMutex{}, p: p}, nil +} + +// Read returns data from the file. If the file doesn't exist, Read returns a nil slice and error. +func (s *Storage) Read(context.Context) ([]byte, error) { + s.m.RLock() + defer s.m.RUnlock() + + data, err := os.ReadFile(s.p) + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + if err != nil { + return nil, err + } + if len(data) > 0 { + data, err = dpapi(decrypt, data) + } + return data, err +} + +// Write stores data in the file, creating the file if it doesn't exist. +func (s *Storage) Write(ctx context.Context, data []byte) error { + s.m.Lock() + defer s.m.Unlock() + + data, err := dpapi(encrypt, data) + if err != nil { + return err + } + err = os.WriteFile(s.p, data, 0600) + if errors.Is(err, os.ErrNotExist) { + dir := filepath.Dir(s.p) + if err = os.MkdirAll(dir, 0700); err == nil { + err = os.WriteFile(s.p, data, 0600) + } + } + return err +} + +type operation int + +const ( + decrypt operation = iota + encrypt +) + +func dpapi(op operation, data []byte) (result []byte, err error) { + out := windows.DataBlob{} + defer func() { + if out.Data != nil { + _, e := windows.LocalFree(windows.Handle(unsafe.Pointer(out.Data))) + // prefer returning DPAPI errors because they're more interesting than LocalFree errors + if e != nil && err == nil { + err = e + } + } + }() + in := windows.DataBlob{Data: &data[0], Size: uint32(len(data))} + switch op { + case decrypt: + // https://learn.microsoft.com/windows/win32/api/dpapi/nf-dpapi-cryptunprotectdata + err = windows.CryptUnprotectData(&in, nil, nil, 0, nil, windows.CRYPTPROTECT_UI_FORBIDDEN, &out) + case encrypt: + // https://learn.microsoft.com/windows/win32/api/dpapi/nf-dpapi-cryptprotectdata + err = windows.CryptProtectData(&in, nil, nil, 0, nil, windows.CRYPTPROTECT_UI_FORBIDDEN, &out) + default: + err = errors.New("invalid operation") + } + if err == nil { + result = make([]byte, out.Size) + copy(result, unsafe.Slice(out.Data, out.Size)) + } + return result, err +} diff --git a/extensions/cache/accessor/windows_test.go b/extensions/cache/accessor/windows_test.go new file mode 100644 index 0000000..0bbebe7 --- /dev/null +++ b/extensions/cache/accessor/windows_test.go @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +//go:build windows +// +build windows + +package accessor + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWriteEncryption(t *testing.T) { + p := filepath.Join(t.TempDir(), t.Name()) + a, err := New(p) + require.NoError(t, err) + + data := []byte(`{"key":"value"}`) + require.NoError(t, json.Unmarshal(data, &struct{}{}), "test bug: data should unmarshal") + require.NoError(t, a.Write(ctx, data)) + + // Write should have encrypted data before writing it to the file + actual, err := os.ReadFile(p) + require.NoError(t, err) + require.NotEmpty(t, actual) + err = json.Unmarshal(actual, &struct{}{}) + require.Error(t, err, "Unmarshal should fail because the file's content, being encrypted, isn't JSON") +} diff --git a/go.mod b/go.mod index b5b0c11..9014856 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.18 require ( github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 github.com/stretchr/testify v1.8.2 + golang.org/x/sys v0.8.0 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c ) @@ -17,6 +18,5 @@ require ( github.com/kylelemons/godebug v1.1.0 // indirect github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/sys v0.5.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 1bcda28..526f8aa 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=