Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 33 additions & 8 deletions passwordless.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,28 @@ func (p *Passwordless) RequestToken(ctx context.Context, s, uid, recipient strin
}
}

// PostVerifyAction is an action to take after validation has succesfully occured.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// PostVerifyAction is an action to take after validation has succesfully occured.
// PostVerifyAction is an action to take if validation succeeds.

type PostVerifyAction func(ctx context.Context, s TokenStore, uid, token string, valid bool) (bool, error)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unclear on the purpose of the bool in the return values here. When would it ever be different to the incoming valid value?


// WithValidDelete when a token is a valid, this deletes it from the store.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// WithValidDelete when a token is a valid, this deletes it from the store.
// WithValidDelete deletes a token from the store when validation succeeds.

func WithValidDelete() PostVerifyAction {
return func(ctx context.Context, s TokenStore, uid, _ string, valid bool) (bool, error) {
if valid {
return valid, s.Delete(ctx, uid)
}
return valid, nil
Comment on lines +122 to +125
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if valid {
return valid, s.Delete(ctx, uid)
}
return valid, nil
if valid {
return true, s.Delete(ctx, uid)
}
return false, nil

}
}

// VerifyToken verifies the provided token is valid.
func (p *Passwordless) VerifyToken(ctx context.Context, uid, token string) (bool, error) {
return VerifyToken(ctx, p.Store, uid, token)
}

func (p *Passwordless) VerifyTokenWithOptions(ctx context.Context, uid, token string, actions ...PostVerifyAction) (bool, error) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing method doc comment.

return VerifyTokenWithOptions(ctx, p.Store, uid, token, actions...)
}

// RequestToken generates, saves and delivers a token to the specified
// recipient.
func RequestToken(ctx context.Context, s TokenStore, t Strategy, uid, recipient string) error {
Expand All @@ -136,16 +153,24 @@ func RequestToken(ctx context.Context, s TokenStore, t Strategy, uid, recipient
return nil
}

// VerifyToken checks the given token against the provided token store.
// VerifyToken checks the given token against the provided token store, on successful
// validation it deletes the token.
func VerifyToken(ctx context.Context, s TokenStore, uid, token string) (bool, error) {
if isValid, err := s.Verify(ctx, token, uid); err != nil {
return VerifyTokenWithOptions(ctx, s, uid, token, WithValidDelete())
}

// VerifyTokenWithOptions
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incomplete func doc comment here. It would be good to clearly describe how actions work, and what happens if an action returns an error, for example.

func VerifyTokenWithOptions(ctx context.Context, s TokenStore, uid, token string, actions ...PostVerifyAction) (bool, error) {
isValid, err := s.Verify(ctx, token, uid)
if err != nil {
// Failed to validate
return false, err
} else if !isValid {
// Token is not valid
return false, nil
} else {
// Token *is* valid; remove old token
return true, s.Delete(ctx, uid)
}
for _, action := range actions {
isValid, err = action(ctx, s, uid, token, isValid)
if err != nil {
return isValid, err
}
}
return isValid, nil
}
76 changes: 61 additions & 15 deletions passwordless_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"context"

"github.com/stretchr/testify/assert"
)

type testTransport struct {
Expand Down Expand Up @@ -35,32 +36,33 @@ func (g testGenerator) Sanitize(ctx context.Context, s string) (string, error) {
}

func TestPasswordless(t *testing.T) {
ctx := context.TODO()
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests should use the background context.

Suggested change
ctx := context.TODO()
ctx := context.Background()

p := New(NewMemStore())

tt := &testTransport{}
tg := &testGenerator{token: "1337"}
s := p.SetTransport("test", tt, tg, 5*time.Minute)

// Check transports match those set
assert.Equal(t, map[string]Strategy{"test": s}, p.ListStrategies(nil))
if s0, err := p.GetStrategy(nil, "test"); err != nil {
assert.Equal(t, map[string]Strategy{"test": s}, p.ListStrategies(ctx))
if s0, err := p.GetStrategy(ctx, "test"); err != nil {
assert.NoError(t, err)
} else {
assert.Equal(t, s, s0)
}

// Check returned token is as expected
assert.NoError(t, p.RequestToken(nil, "test", "uid", "recipient"))
assert.NoError(t, p.RequestToken(ctx, "test", "uid", "recipient"))
assert.Equal(t, tt.token, tg.token)
assert.Equal(t, tt.recipient, "recipient")

// Check invalid token is rejected
v, err := p.VerifyToken(nil, "uid", "badtoken")
v, err := p.VerifyToken(ctx, "uid", "badtoken")
assert.NoError(t, err)
assert.False(t, v)

// Verify token
v, err = p.VerifyToken(nil, "uid", tg.token)
v, err = p.VerifyToken(ctx, "uid", tg.token)
assert.NoError(t, err)
assert.True(t, v)
}
Expand All @@ -75,30 +77,32 @@ func (s testStrategy) Valid(c context.Context) bool {
}

func TestPasswordlessFailures(t *testing.T) {
ctx := context.TODO()
p := New(NewMemStore())

_, err := p.GetStrategy(nil, "madeup")
_, err := p.GetStrategy(ctx, "madeup")
assert.Equal(t, err, ErrUnknownStrategy)

err = p.RequestToken(nil, "madeup", "", "")
err = p.RequestToken(ctx, "madeup", "", "")
assert.Equal(t, err, ErrUnknownStrategy)

p.SetStrategy("unfriendly", testStrategy{valid: false})

err = p.RequestToken(nil, "unfriendly", "", "")
err = p.RequestToken(ctx, "unfriendly", "", "")
assert.Equal(t, err, ErrNotValidForContext)
}

func TestRequestToken(t *testing.T) {
ctx := context.TODO()
// Test Generate()
assert.EqualError(t, RequestToken(nil, nil, &mockStrategy{
assert.EqualError(t, RequestToken(ctx, nil, &mockStrategy{
generate: func(c context.Context) (string, error) {
return "", fmt.Errorf("refused generate")
},
}, "", ""), "refused generate", "Generate() error should propagate")

// Test Send()
assert.EqualError(t, RequestToken(nil, &mockTokenStore{
assert.EqualError(t, RequestToken(ctx, &mockTokenStore{
store: func(ctx context.Context, token, uid string, ttl time.Duration) error {
return nil
},
Expand All @@ -112,7 +116,7 @@ func TestRequestToken(t *testing.T) {
}, "", ""), "refused send", "Send() error should propagate")

// Test Store()
err := RequestToken(nil, &mockTokenStore{
err := RequestToken(ctx, &mockTokenStore{
store: func(ctx context.Context, token, uid string, ttl time.Duration) error {
return fmt.Errorf("refused store")
},
Expand All @@ -128,23 +132,24 @@ func TestRequestToken(t *testing.T) {
}

func TestVerifyToken(t *testing.T) {
valid, err := VerifyToken(nil, &mockTokenStore{
ctx := context.TODO()
valid, err := VerifyToken(ctx, &mockTokenStore{
verify: func(ctx context.Context, token, uid string) (bool, error) {
return false, fmt.Errorf("refused verify")
},
}, "", "")
assert.False(t, valid)
assert.EqualError(t, err, "refused verify", "Verify() error should propagate")

valid, err = VerifyToken(nil, &mockTokenStore{
valid, err = VerifyToken(ctx, &mockTokenStore{
verify: func(ctx context.Context, token, uid string) (bool, error) {
return false, nil
},
}, "", "")
assert.False(t, valid)
assert.NoError(t, err)

valid, err = VerifyToken(nil, &mockTokenStore{
valid, err = VerifyToken(ctx, &mockTokenStore{
verify: func(ctx context.Context, token, uid string) (bool, error) {
return true, nil
},
Expand All @@ -156,6 +161,47 @@ func TestVerifyToken(t *testing.T) {
assert.EqualError(t, err, "delete failure")
}

func TestVerifyTokenWithOptions(t *testing.T) {
ctx := context.TODO()
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests should use the background context.

valid, err := VerifyTokenWithOptions(ctx, &mockTokenStore{
verify: func(ctx context.Context, token, uid string) (bool, error) {
return false, fmt.Errorf("refused verify")
},
}, "", "")
assert.False(t, valid)
assert.EqualError(t, err, "refused verify", "Verify() error should propagate")

valid, err = VerifyTokenWithOptions(ctx, &mockTokenStore{
verify: func(ctx context.Context, token, uid string) (bool, error) {
return false, nil
},
}, "", "")
assert.False(t, valid)
assert.NoError(t, err)

valid, err = VerifyTokenWithOptions(ctx, &mockTokenStore{
verify: func(ctx context.Context, token, uid string) (bool, error) {
return true, nil
},
delete: func(ctx context.Context, uid string) error {
return fmt.Errorf("delete failure")
},
}, "", "")
assert.True(t, valid)
assert.NoError(t, err)

valid, err = VerifyTokenWithOptions(ctx, &mockTokenStore{
verify: func(ctx context.Context, token, uid string) (bool, error) {
return true, nil
},
delete: func(ctx context.Context, uid string) error {
return fmt.Errorf("delete failure")
},
}, "", "", WithValidDelete())
assert.True(t, valid)
assert.EqualError(t, err, "delete failure")
}

type mockStrategy struct {
SimpleStrategy
generate func(context.Context) (string, error)
Expand Down
37 changes: 22 additions & 15 deletions store_redis_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package passwordless

import (
"context"
"log"
"testing"
"time"
Expand All @@ -25,7 +26,7 @@ func newRedisMock() *redisMock {
}
}

func (r redisMock) Set(key string, value interface{}, expiration time.Duration) *redis.StatusCmd {
func (r redisMock) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd {
val := rval{
d: expiration,
}
Expand All @@ -37,7 +38,7 @@ func (r redisMock) Set(key string, value interface{}, expiration time.Duration)
return redis.NewStatusResult(key, nil)
}

func (r redisMock) TTL(key string) *redis.DurationCmd {
func (r redisMock) TTL(ctx context.Context, key string) *redis.DurationCmd {
v, ok := r.store[key]
if !ok {
return redis.NewDurationResult(-1*time.Second, nil)
Expand All @@ -46,7 +47,7 @@ func (r redisMock) TTL(key string) *redis.DurationCmd {
return cmd
}

func (r redisMock) Get(key string) *redis.StringCmd {
func (r redisMock) Get(ctx context.Context, key string) *redis.StringCmd {
v, ok := r.store[key]
if !ok {
return redis.NewStringResult("", redis.Nil)
Expand All @@ -58,58 +59,64 @@ func (r redisMock) Get(key string) *redis.StringCmd {
return redis.NewStringResult(v.v, nil)
}

func (r redisMock) Del(keys ...string) *redis.IntCmd {
func (r redisMock) Del(ctx context.Context, keys ...string) *redis.IntCmd {
for _, k := range keys {
delete(r.store, k)
}
return redis.NewIntResult(1, nil)
}

func TestRedisStore(t *testing.T) {
ctx := context.TODO()
ms := NewRedisStore(newRedisMock())
assert.NotNil(t, ms)

b, exp, err := ms.Exists(nil, "uid")
b, exp, err := ms.Exists(ctx, "uid")
assert.False(t, b)
assert.True(t, exp.IsZero())
assert.NoError(t, err)

err = ms.Store(nil, "", "uid", -time.Hour)
b, exp, err = ms.Exists(nil, "uid")
err = ms.Store(ctx, "", "uid", -time.Hour)
assert.NoError(t, err)
b, exp, err = ms.Exists(ctx, "uid")
assert.False(t, b)
assert.True(t, exp.IsZero())
assert.NoError(t, err)

err = ms.Store(nil, "", "uid", time.Hour)
b, exp, err = ms.Exists(nil, "uid")
err = ms.Store(ctx, "", "uid", time.Hour)
assert.NoError(t, err)
b, exp, err = ms.Exists(ctx, "uid")
log.Println(b, exp, err)
assert.True(t, b)
assert.False(t, exp.IsZero())
}

func TestRedisStoreVerify(t *testing.T) {
ctx := context.TODO()
ms := NewRedisStore(newRedisMock())
assert.NotNil(t, ms)

// Token doesn't exist
b, err := ms.Verify(nil, "badtoken", "uid")
b, err := ms.Verify(ctx, "badtoken", "uid")
assert.False(t, b)
assert.Equal(t, ErrTokenNotFound, err)

// Token expired
err = ms.Store(nil, "", "uid", -time.Hour)
b, err = ms.Verify(nil, "badtoken", "uid")
err = ms.Store(ctx, "", "uid", -time.Hour)
assert.NoError(t, err)
b, err = ms.Verify(ctx, "badtoken", "uid")
assert.False(t, b)
assert.Equal(t, ErrTokenNotFound, err)

// Token wrong
err = ms.Store(nil, "token", "uid", time.Hour)
b, err = ms.Verify(nil, "badtoken", "uid")
err = ms.Store(ctx, "token", "uid", time.Hour)
assert.NoError(t, err)
b, err = ms.Verify(ctx, "badtoken", "uid")
assert.False(t, b)
assert.NoError(t, err)

// Token correct
b, err = ms.Verify(nil, "token", "uid")
b, err = ms.Verify(ctx, "token", "uid")
assert.True(t, b)
assert.NoError(t, err)
}