Skip to content
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@

# Dependency directories (remove the comment below to include it)
# vendor/
.vscode/
80 changes: 80 additions & 0 deletions extensions/lock/lock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package lock

import (
"errors"
"fmt"
"os"
"sync"
"time"

"github.com/AzureAD/microsoft-authentication-extensions-for-go/flock"
)

type Lock struct {
retries int
retryDelay time.Duration

lockFile *os.File
lockfileName string

fLock *flock.Flock
mu sync.Mutex
}

type Option func(l *Lock)

func WithRetries(n int) Option {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Public requires a comment. I'd put:

// WithRetries changes the default number of retries from [retries] to n. Negative values panic()

Also, add panic on negative values.

return func(l *Lock) {
l.retries = n
}
}
func WithRetryDelay(t time.Duration) Option {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Public requires a comment. I'd put:

// WithRetryDelay changes the default delay from [delay] to t.

return func(l *Lock) {
l.retryDelay = t
}
}

func New(lockFileName string, options ...Option) (*Lock, error) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lockFileName can be just "p".

// New creates a Lock for a lock file at "p". If "p" doesn't exist, it will be created when using. Lock().
func New(p string options ...Option) (*Lock, error)

l := &Lock{}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to add default value for retries and retryDelay.

for _, o := range options {
o(l)
}
l.fLock = flock.New(lockFileName)
l.lockfileName = lockFileName
return l, nil
}

func (l *Lock) Lock() error {
l.mu.Lock()
defer l.mu.Unlock()
for tryCount := 0; tryCount < l.retries; tryCount++ {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for i := 0; i < l.retries; i++ {
...

lockfile, err := os.OpenFile(lockFileName, os.O_RDWR|os.O_CREATE)
if err != nil {
return &Lock{}, err
}
err := l.fLock.Lock()
if err != nil {
time.Sleep(l.retryDelay * time.Millisecond)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be:
time.Sleep(l.retryDelay)

continue
} else {
if l.fLock.Locked() {
l.fLock.Fh.WriteString(fmt.Sprintf("{%d} {%s}", os.Getpid(), os.Args[0]))
}
return nil
}
}
return errors.New("failed to acquire lock")
}

func (l *Lock) UnLock() error {
if l.fLock != nil {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the "if l.fLock != nil", doesn't provide value.

if err := l.fLock.Unlock(); err != nil {
return err
}
l.lockFile.Close()
if err := os.Remove(l.fLock.Path()); err != nil {
return err
}
}
return nil
}
108 changes: 108 additions & 0 deletions extensions/lock/lock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package lock

import (
"bytes"
"fmt"
"log"
"os"
"strings"
"sync"
"testing"
"time"

"github.com/AzureAD/microsoft-authentication-extensions-for-go/internal"
)

func spinThreads(noOfThreads int, sleepInterval time.Duration, t *testing.T) int {
cacheFile := "cache.txt"
var wg sync.WaitGroup
wg.Add(noOfThreads)
for i := 0; i < noOfThreads; i++ {
go func(i int) {
defer wg.Done()
acquireLockAndWriteToCache(i, sleepInterval, cacheFile)
}(i)
}
wg.Wait()
t.Cleanup(func() {
if err := os.Remove(cacheFile); err != nil {
log.Println("Failed to remove cache file", err)
}
})
return validateResult(cacheFile, t)
}

func acquireLockAndWriteToCache(threadNo int, sleepInterval time.Duration, cacheFile string) {
cacheAccessor := internal.NewFileAccessor(cacheFile)
lockfileName := cacheFile + ".lockfile"
l, err := New(lockfileName, WithRetries(60), WithRetryDelay(100))
if err := l.Lock(); err != nil {
log.Println("Couldn't acquire lock", err.Error())
return
}
defer l.UnLock()
data, err := cacheAccessor.Read()
if err != nil {
log.Println(err)
}
var buffer bytes.Buffer
buffer.Write(data)
buffer.WriteString(fmt.Sprintf("< %d \n", threadNo))
time.Sleep(sleepInterval * time.Millisecond)
buffer.WriteString(fmt.Sprintf("> %d \n", threadNo))
cacheAccessor.Write(buffer.Bytes())
}

func validateResult(cacheFile string, t *testing.T) int {
count := 0
var prevProc string = ""
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

var
(
count int
prevProc, tag, proc string
)
)

var tag string
var proc string
data, err := os.ReadFile(cacheFile)
if err != nil {
log.Println(err)
}
dat := string(data)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All in one line and call it s instead of ele

temp := strings.Split(dat, "\n")
for _, ele := range temp {
if ele != "" {
count += 1
split := strings.Split(ele, " ")
tag = split[0]
proc = split[1]
if prevProc != "" {
if proc != prevProc {
t.Fatal("Process overlap found")
}
if tag != ">" {
t.Fatal("Process overlap found")
}
prevProc = ""

} else {
if tag != "<" {
t.Fatal("Opening bracket not found")
}
prevProc = proc
}
}
}
return count
}
func TestForNormalWorkload(t *testing.T) {
noOfThreads := 4
sleepInterval := 100
n := spinThreads(noOfThreads, time.Duration(sleepInterval), t)
if n != 4*2 {
t.Fatalf("Should not observe starvation")
}
}

func TestForHighWorkload(t *testing.T) {
noOfThreads := 80
sleepInterval := 100
n := spinThreads(noOfThreads, time.Duration(sleepInterval), t)
if n > 80*2 {
t.Fatalf("Starvation or not, we should not observe garbled payload")
}
}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add two tests:

One tests all the conditions of Lock().
Another tests all the conditions of Unlock()

  • flock.Unlock() has error
  • Lockfile can't be removed

You can do this by providing an interface:

type flocker interface{
TryLock() error
Unlock() error
Path() string
}

type fakeFlock struct {
err bool
}

func (f fakeFlock) TryLock() error {
if f.err {
return errors.New("error")
}
return nil
}

func (f fakeFlock) Unlock() error {
if f.err {
return errors.New("error")
}
return nil
}

10 changes: 10 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module github.com/AzureAD/microsoft-authentication-extensions-for-go

go 1.14

require (
github.com/AzureAD/microsoft-authentication-library-for-go v0.3.0
github.com/billgraziano/dpapi v0.4.0
github.com/gofrs/flock v0.8.1

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This flock line shouldn't. be here. So either a path change is needed or you need to run go tidy


)
197 changes: 197 additions & 0 deletions go.sum

Large diffs are not rendered by default.

50 changes: 50 additions & 0 deletions internal/cache_accessor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package internal

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move these files under:

internal/accessor
accessor.go
file/
file.go
windows/
window.go

accessor.go:

  • Defines type Cache interface{}
    file.go:
  • Defines Accessor struct{}
  • Defines New() constructor
    windows.go
  • Defines Accessor struct{}
  • Defines New() constructor

All "Accessor" types implement accessor.Cache


import (
"io/ioutil"
"os"
)

type cacheAccessor interface {
Read() ([]byte, error)
Write(data []byte)
Delete()
}

type FileAccessor struct {
cacheFilePath string
}

func NewFileAccessor(cacheFilePath string) *FileAccessor {
return &FileAccessor{cacheFilePath: cacheFilePath}
}

func (f *FileAccessor) Read() ([]byte, error) {
var data []byte
file, err := os.Open(f.cacheFilePath)
if err != nil {
return nil, err
}
defer file.Close()
data, err = ioutil.ReadAll(file)
if err != nil {
return nil, err
}
return data, nil
}

func (f *FileAccessor) Write(data []byte) error {
err := ioutil.WriteFile(f.cacheFilePath, data, 0600)
if err != nil {
return err
}
return nil
}

func (f *FileAccessor) WriteAtomic(data []byte) {
// Not implemented yet
}

func (f *FileAccessor) Delete() {

}
60 changes: 60 additions & 0 deletions internal/windows_persistence.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//go:build windows
// +build windows

package internal

import (
"io/ioutil"
"os"
"runtime"

"github.com/billgraziano/dpapi"
)

type WindowsAccessor struct {
cacheFilePath string
}

func NewWindowsAccessor(cacheFilePath string) *WindowsAccessor {
return &WindowsAccessor{cacheFilePath: cacheFilePath}
}

func (w *WindowsAccessor) Read() ([]byte, error) {
var data []byte
file, err := os.Open(w.cacheFilePath)
if err != nil {
return nil, err
}
defer file.Close()
data, err = ioutil.ReadAll(file)
if err != nil {
return nil, err
}
if data != nil && len(data) != 0 && runtime.GOOS == "windows" {
data, err = dpapi.DecryptBytes(data)
if err != nil {
return nil, err
}
}
return data, nil
}

func (w *WindowsAccessor) Write(data []byte) error {
data, err := dpapi.EncryptBytes(data)
if err != nil {
return err
}
err = ioutil.WriteFile(w.cacheFilePath, data, 0600)
if err != nil {
return err
}
return nil
}

func (w *WindowsAccessor) WriteAtomic(data []byte) {
// Not implemented yet
}

func (w *WindowsAccessor) Delete() {

}
Loading