Skip to content

Commit d5f43db

Browse files
authored
introduce [Header|Query|Cookie]Auth helper functions to streamline token extraction (#47)
1 parent 2b23793 commit d5f43db

File tree

7 files changed

+174
-38
lines changed

7 files changed

+174
-38
lines changed

config/config_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ func TestComplex(t *testing.T) {
118118
{"Complex", func() { config.File = ".envCOMPLEX" }, &ComplexConfig{}, nil, `{"TestMap":{"one":"1","two":"2"},"Time":"2022-01-01T00:00:00Z","Patterns":["123","asdf","https://a.b/c","^http:*"]}`, false},
119119
{"BadTime", func() { config.File = ".envBADTIME" }, &ComplexConfig{}, nil, `{"TestMap":null,"Time":"0001-01-01T00:00:00Z"}`, true},
120120
{"BadMap", nil, &ComplexConfig{}, map[string]string{"TEST_MAP": "FOO"}, `{"TestMap":{},"Time":"2021-01-01T00:00:00Z"}`, false},
121+
{"BadRegex", func() { config.File = ".envEMPTY" }, &ComplexConfig{}, map[string]string{"PATTERNS": "a(b"}, `{"TestMap":null,"Time":"0001-01-01T00:00:00Z","Patterns":[""]}`, true},
121122
}
122123

123124
for _, tt := range tests {
@@ -160,7 +161,7 @@ type ExampleConfig struct {
160161

161162
func ExampleLoad() {
162163
cfg := ExampleConfig{}
163-
config.Load(&cfg)
164+
_ = config.Load(&cfg)
164165
fmt.Printf("DebugMode=%v\n", cfg.DebugMode)
165166
fmt.Printf("Port=%v\n", cfg.Port)
166167
fmt.Printf("DB=%v\n", cfg.DB)
@@ -172,7 +173,7 @@ func TestFoo(t *testing.T) {
172173

173174
cfg := ExampleConfig{}
174175

175-
config.Load(&cfg)
176+
_ = config.Load(&cfg)
176177
test := viper.GetStringSlice("TEST_ARRAY")
177178
fmt.Printf("TEST=%#v\n", cfg.Test)
178179
fmt.Printf("test=%#v\n", test)

config/hooks.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ func StringToURLHookFunc(f reflect.Type, t reflect.Type, data any) (any, error)
8181
// ErrInvalidTime is returned when a time tag fails to parse.
8282
var ErrInvalidTime = errors.New("failed parsing time")
8383

84+
// ErrInvalidRegex is returned when a regex tag fails to parse.
85+
var ErrInvalidRegex = errors.New("failed parsing regex")
86+
8487
// StringToTimeFunc converts strings to time.Time.
8588
func StringToTimeFunc(f reflect.Type, t reflect.Type, data any) (any, error) {
8689
if f.Kind() != reflect.String {
@@ -121,7 +124,7 @@ func StringToRegexFunc(f reflect.Type, t reflect.Type, data any) (any, error) {
121124

122125
out, err := regexp.Compile(s)
123126
if err != nil {
124-
return nil, fmt.Errorf("%w: `%v`", ErrInvalidTime, data)
127+
return nil, fmt.Errorf("%w: %q", ErrInvalidRegex, s)
125128
}
126129

127130
return out, nil

currency/fixed_format_test.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package currency_test
33
import (
44
"testing"
55

6+
"github.com/stretchr/testify/assert"
7+
68
"github.com/bir/iken/currency"
79
)
810

@@ -37,3 +39,8 @@ func TestFixedFormatter_Format(t *testing.T) {
3739
})
3840
}
3941
}
42+
43+
func TestPower10(t *testing.T) {
44+
assert.Equal(t, currency.Power10(0), 1)
45+
assert.Equal(t, currency.Power10(1), 10)
46+
}

httputil/auth.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ const (
3232
// request it returns an error.
3333
type AuthenticateFunc[T any] func(r *http.Request) (T, error)
3434

35+
// TokenAuthenticatorFunc is the signature of a function used to authenticate a request given just the token.
36+
// Given a request, it returns the authenticated user. If unable to authenticate the
37+
// request it returns an error.
38+
type TokenAuthenticatorFunc[T any] func(ctx context.Context, token string) (T, error)
39+
3540
// AuthorizeFunc is the signature of a function used to authorize a request. If unable
3641
// to authorize the user it returns an error.
3742
type AuthorizeFunc[T any] func(ctx context.Context, user T, scopes []string) error
@@ -44,6 +49,45 @@ type AuthCheck[T any] struct {
4449
scopes []string
4550
}
4651

52+
func HeaderAuth[T any](key string, fn TokenAuthenticatorFunc[T]) AuthenticateFunc[T] {
53+
return func(r *http.Request) (T, error) {
54+
var empty T
55+
56+
token := r.Header.Get(key)
57+
if token == "" {
58+
return empty, ErrUnauthorized
59+
}
60+
61+
return fn(r.Context(), token)
62+
}
63+
}
64+
65+
func QueryAuth[T any](key string, fn TokenAuthenticatorFunc[T]) AuthenticateFunc[T] {
66+
return func(r *http.Request) (T, error) {
67+
var empty T
68+
69+
token := r.URL.Query().Get(key)
70+
if token == "" {
71+
return empty, ErrUnauthorized
72+
}
73+
74+
return fn(r.Context(), token)
75+
}
76+
}
77+
78+
func CookieAuth[T any](key string, fn TokenAuthenticatorFunc[T]) AuthenticateFunc[T] {
79+
return func(r *http.Request) (T, error) {
80+
var empty T
81+
82+
cookie, err := r.Cookie(key)
83+
if err != nil || cookie == nil || len(cookie.Value) == 0 {
84+
return empty, ErrUnauthorized
85+
}
86+
87+
return fn(r.Context(), cookie.Value)
88+
}
89+
}
90+
4791
func NewAuthCheck[T any](authenticate AuthenticateFunc[T], authorize AuthorizeFunc[T], scopes ...string) AuthCheck[T] {
4892
return AuthCheck[T]{
4993
authenticate: authenticate,

httputil/auth_test.go

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package httputil_test
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"net/http"
78
"net/http/httptest"
89
"slices"
@@ -32,7 +33,7 @@ func authenticate(r *http.Request) (string, error) {
3233
return "", errors.New("missing")
3334
}
3435

35-
func authorize(ctx context.Context, user string, scopes []string) error {
36+
func authorize(_ context.Context, user string, scopes []string) error {
3637
if slices.Contains(scopes, user) {
3738
return nil
3839
}
@@ -140,3 +141,111 @@ func TestSecurityGroups_Auth(t *testing.T) {
140141
})
141142
}
142143
}
144+
145+
func strAuth(_ context.Context, token string) (string, error) {
146+
if token == "" {
147+
return "", errors.New("unreachable")
148+
}
149+
150+
if token == "good" {
151+
return "good", nil
152+
}
153+
154+
if token == "bad" {
155+
return "", errors.New("bad")
156+
}
157+
158+
return "", errors.New("badder")
159+
}
160+
161+
func TestHeaderAuth(t *testing.T) {
162+
type testCase[T any] struct {
163+
name string
164+
key string
165+
val string
166+
fn httputil.TokenAuthenticatorFunc[T]
167+
want string
168+
err error
169+
}
170+
tests := []testCase[string]{
171+
{"Empty", "Missing", "", strAuth, "", httputil.ErrUnauthorized},
172+
{"Good", "Authorization", "good", strAuth, "good", nil},
173+
{"Bad", "Authorization", "bad", strAuth, "", errors.New("bad")},
174+
{"other", "Authorization", "other", strAuth, "", errors.New("badder")},
175+
}
176+
for _, tt := range tests {
177+
t.Run(tt.name, func(t *testing.T) {
178+
r := httptest.NewRequest("FOO", "/asdf", nil)
179+
r.Header.Set(tt.key, tt.val)
180+
181+
got, err := httputil.HeaderAuth(tt.key, tt.fn)(r)
182+
if tt.err != nil {
183+
assert.Equal(t, tt.err, err)
184+
}
185+
186+
assert.Equalf(t, tt.want, got, "HeaderAuth(%v, %v)", tt.key, tt.fn)
187+
})
188+
}
189+
}
190+
191+
func TestQueryAuth(t *testing.T) {
192+
type testCase[T any] struct {
193+
name string
194+
key string
195+
val string
196+
fn httputil.TokenAuthenticatorFunc[T]
197+
want string
198+
err error
199+
}
200+
tests := []testCase[string]{
201+
{"Empty", "Missing", "", strAuth, "", httputil.ErrUnauthorized},
202+
{"Good", "Authorization", "good", strAuth, "good", nil},
203+
{"Bad", "Authorization", "bad", strAuth, "", errors.New("bad")},
204+
{"other", "Authorization", "other", strAuth, "", errors.New("badder")},
205+
}
206+
for _, tt := range tests {
207+
t.Run(tt.name, func(t *testing.T) {
208+
r := httptest.NewRequest("FOO", fmt.Sprintf("/asdf?%s=%s", tt.key, tt.val), nil)
209+
210+
got, err := httputil.QueryAuth(tt.key, tt.fn)(r)
211+
if tt.err != nil {
212+
assert.Equal(t, tt.err, err)
213+
}
214+
215+
assert.Equal(t, tt.want, got)
216+
})
217+
}
218+
}
219+
220+
func TestCookieAuth(t *testing.T) {
221+
type testCase[T any] struct {
222+
name string
223+
key string
224+
val string
225+
fn httputil.TokenAuthenticatorFunc[T]
226+
want string
227+
err error
228+
}
229+
tests := []testCase[string]{
230+
{"Empty", "Missing", "", strAuth, "", httputil.ErrUnauthorized},
231+
{"Good", "Authorization", "good", strAuth, "good", nil},
232+
{"Bad", "Authorization", "bad", strAuth, "", errors.New("bad")},
233+
{"other", "Authorization", "other", strAuth, "", errors.New("badder")},
234+
}
235+
for _, tt := range tests {
236+
t.Run(tt.name, func(t *testing.T) {
237+
r := httptest.NewRequest("FOO", "/asdf", nil)
238+
r.AddCookie(&http.Cookie{
239+
Name: tt.key,
240+
Value: tt.val,
241+
})
242+
243+
got, err := httputil.CookieAuth(tt.key, tt.fn)(r)
244+
if tt.err != nil {
245+
assert.Equal(t, tt.err, err)
246+
}
247+
248+
assert.Equal(t, tt.want, got)
249+
})
250+
}
251+
}

httputil/wrap_writer.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,6 @@ func (b *basicWriter) Write(buf []byte) (int, error) {
8282
return n, err
8383
}
8484

85-
func (b *basicWriter) maybeWriteHeader() {
86-
if !b.wroteHeader {
87-
b.WriteHeader(http.StatusOK)
88-
}
89-
}
90-
9185
func (b *basicWriter) Status() int {
9286
return b.code
9387
}
@@ -110,6 +104,12 @@ func (b *basicWriter) Flush() {
110104
}
111105
}
112106

107+
func (b *basicWriter) maybeWriteHeader() {
108+
if !b.wroteHeader {
109+
b.WriteHeader(http.StatusOK)
110+
}
111+
}
112+
113113
// fancyWriter is a writer that additionally satisfies http.CloseNotifier,
114114
// http.Flusher, http.Hijacker, and io.ReaderFrom. It exists for the common case
115115
// of wrapping the http.ResponseWriter that package http gives you, in order to

pgxutil/scan.go

Lines changed: 0 additions & 28 deletions
This file was deleted.

0 commit comments

Comments
 (0)