Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions extensions/cache/accessor/storage_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.

// TODO: add other platforms
//go:build windows
// +build windows

package accessor

import (
"context"
"os"
"path/filepath"
"runtime"
"testing"

"github.com/stretchr/testify/require"
)

const msalextManualTest = "MSALEXT_MANUAL_TEST"

var (
ctx = context.Background()

// the Windows implementation doesn't require user interaction
manualTests = runtime.GOOS == "windows" || os.Getenv(msalextManualTest) != ""
)

func TestReadWrite(t *testing.T) {
if !manualTests {
t.Skipf("set %s to run this test", msalextManualTest)
}
for _, test := range []struct {
desc string
want []byte
}{
{desc: "Test when no stored data exists"},
{desc: "Test writing data then reading it", want: []byte("want")},
} {
t.Run(test.desc, func(t *testing.T) {
p := filepath.Join(t.TempDir(), t.Name())
a, err := New(p)
require.NoError(t, err)

if test.want != nil {
cp := make([]byte, len(test.want))
copy(cp, test.want)
err = a.Write(ctx, cp)
require.NoError(t, err)
}

actual, err := a.Read(ctx)
require.NoError(t, err)
require.Equal(t, test.want, actual)
})
}
}
102 changes: 102 additions & 0 deletions extensions/cache/accessor/windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.

//go:build windows
// +build windows

package accessor

import (
"context"
"errors"
"os"
"path/filepath"
"sync"
"unsafe"

"golang.org/x/sys/windows"
)

// Storage stores data in a file encrypted by the Windows data protection API.
type Storage struct {
m *sync.RWMutex
p string
}

// New is the constructor for Storage. "p" is the path to the file in which to store data.
func New(p string) (*Storage, error) {
return &Storage{m: &sync.RWMutex{}, p: p}, nil
}

// Read returns data from the file. If the file doesn't exist, Read returns a nil slice and error.
func (s *Storage) Read(context.Context) ([]byte, error) {
s.m.RLock()
defer s.m.RUnlock()

data, err := os.ReadFile(s.p)
if errors.Is(err, os.ErrNotExist) {
return nil, nil
}
if err != nil {
return nil, err
}
if len(data) > 0 {
data, err = dpapi(decrypt, data)
}
return data, err
}

// Write stores data in the file, creating the file if it doesn't exist.
func (s *Storage) Write(ctx context.Context, data []byte) error {
s.m.Lock()
defer s.m.Unlock()

data, err := dpapi(encrypt, data)
if err != nil {
return err
}
err = os.WriteFile(s.p, data, 0600)
if errors.Is(err, os.ErrNotExist) {
dir := filepath.Dir(s.p)
if err = os.MkdirAll(dir, 0700); err == nil {
err = os.WriteFile(s.p, data, 0600)
}
}
return err
}

type operation int

const (
decrypt operation = iota
encrypt
)

func dpapi(op operation, data []byte) (result []byte, err error) {
out := windows.DataBlob{}
defer func() {
if out.Data != nil {
_, e := windows.LocalFree(windows.Handle(unsafe.Pointer(out.Data)))
// prefer returning DPAPI errors because they're more interesting than LocalFree errors
if e != nil && err == nil {
err = e
}
}
}()
in := windows.DataBlob{Data: &data[0], Size: uint32(len(data))}
switch op {
case decrypt:
// https://learn.microsoft.com/windows/win32/api/dpapi/nf-dpapi-cryptunprotectdata
err = windows.CryptUnprotectData(&in, nil, nil, 0, nil, windows.CRYPTPROTECT_UI_FORBIDDEN, &out)
case encrypt:
// https://learn.microsoft.com/windows/win32/api/dpapi/nf-dpapi-cryptprotectdata
err = windows.CryptProtectData(&in, nil, nil, 0, nil, windows.CRYPTPROTECT_UI_FORBIDDEN, &out)
default:
err = errors.New("invalid operation")
}
if err == nil {
result = make([]byte, out.Size)
copy(result, unsafe.Slice(out.Data, out.Size))
}
return result, err
}
33 changes: 33 additions & 0 deletions extensions/cache/accessor/windows_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.

//go:build windows
// +build windows

package accessor

import (
"encoding/json"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/require"
)

func TestWriteEncryption(t *testing.T) {
p := filepath.Join(t.TempDir(), t.Name())
a, err := New(p)
require.NoError(t, err)

data := []byte(`{"key":"value"}`)
require.NoError(t, json.Unmarshal(data, &struct{}{}), "test bug: data should unmarshal")
require.NoError(t, a.Write(ctx, data))

// Write should have encrypted data before writing it to the file
actual, err := os.ReadFile(p)
require.NoError(t, err)
require.NotEmpty(t, actual)
err = json.Unmarshal(actual, &struct{}{})
require.Error(t, err, "Unmarshal should fail because the file's content, being encrypted, isn't JSON")
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.18
require (
github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0
github.com/stretchr/testify v1.8.2
golang.org/x/sys v0.8.0
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c
)

Expand All @@ -17,6 +18,5 @@ require (
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/sys v0.5.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
Expand Down