Skip to content

Commit 1a44627

Browse files
authored
Introduce AccessKey type for improved type safety (#46)
1 parent f1dda52 commit 1a44627

File tree

10 files changed

+87
-74
lines changed

10 files changed

+87
-74
lines changed

access_key.go

Lines changed: 38 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@ import (
44
"context"
55
"errors"
66
"fmt"
7-
"net/http"
87
"strings"
98

109
"crypto/rand"
1110
"encoding/binary"
1211

13-
"github.com/go-chi/transport"
1412
"github.com/goware/base64"
1513
"github.com/jxskiss/base62"
1614
)
@@ -24,10 +22,16 @@ var (
2422
ErrInvalidKeyLength = errors.New("invalid access key length")
2523
)
2624

27-
func GetProjectIDFromAccessKey(accessKey string) (projectID uint64, err error) {
25+
type AccessKey string
26+
27+
func (a AccessKey) String() string {
28+
return string(a)
29+
}
30+
31+
func (a AccessKey) GetProjectID() (projectID uint64, err error) {
2832
var errs []error
2933
for _, e := range SupportedEncodings {
30-
projectID, err := e.Decode(accessKey)
34+
projectID, err := e.Decode(a)
3135
if err != nil {
3236
errs = append(errs, fmt.Errorf("decode v%d: %w", e.Version(), err))
3337
continue
@@ -37,44 +41,34 @@ func GetProjectIDFromAccessKey(accessKey string) (projectID uint64, err error) {
3741
return 0, errors.Join(errs...)
3842
}
3943

40-
func GenerateAccessKey(ctx context.Context, projectID uint64) string {
41-
version, ok := GetVersion(ctx)
42-
if !ok {
43-
return DefaultEncoding.Encode(ctx, projectID)
44-
}
45-
46-
for _, e := range SupportedEncodings {
47-
if e.Version() == version {
48-
return e.Encode(ctx, projectID)
49-
}
50-
}
51-
return ""
52-
}
53-
54-
func GetAccessKeyPrefix(accessKey string) string {
55-
parts := strings.Split(accessKey, Separator)
44+
func (a AccessKey) GetPrefix() string {
45+
parts := strings.Split(a.String(), Separator)
5646
if len(parts) < 2 {
5747
return ""
5848
}
5949
return strings.Join(parts[:len(parts)-1], Separator)
6050
}
6151

62-
func ForwardAccessKeyTransport(next http.RoundTripper) http.RoundTripper {
63-
return transport.RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) {
64-
r := transport.CloneRequest(req)
52+
var ErrUnsupportedEncoding = errors.New("unsupported access key encoding")
6553

66-
if accessKey, ok := GetAccessKey(req.Context()); ok {
67-
r.Header.Set(HeaderAccessKey, accessKey)
68-
}
54+
func GenerateAccessKey(ctx context.Context, projectID uint64) (AccessKey, error) {
55+
version, ok := GetVersion(ctx)
56+
if !ok {
57+
return DefaultEncoding.Encode(ctx, projectID), nil
58+
}
6959

70-
return next.RoundTrip(r)
71-
})
60+
for _, e := range SupportedEncodings {
61+
if e.Version() == version {
62+
return e.Encode(ctx, projectID), nil
63+
}
64+
}
65+
return "", ErrUnsupportedEncoding
7266
}
7367

7468
type Encoding interface {
7569
Version() byte
76-
Encode(ctx context.Context, projectID uint64) string
77-
Decode(accessKey string) (projectID uint64, err error)
70+
Encode(ctx context.Context, projectID uint64) AccessKey
71+
Decode(accessKey AccessKey) (projectID uint64, err error)
7872
}
7973

8074
const (
@@ -89,15 +83,15 @@ type V0 struct{}
8983

9084
func (V0) Version() byte { return 0 }
9185

92-
func (V0) Encode(_ context.Context, projectID uint64) string {
86+
func (V0) Encode(_ context.Context, projectID uint64) AccessKey {
9387
buf := make([]byte, sizeV0)
9488
binary.BigEndian.PutUint64(buf, projectID)
9589
_, _ = rand.Read(buf[8:])
96-
return base62.EncodeToString(buf)
90+
return AccessKey(base62.EncodeToString(buf))
9791
}
9892

99-
func (V0) Decode(accessKey string) (projectID uint64, err error) {
100-
buf, err := base62.DecodeString(accessKey)
93+
func (V0) Decode(accessKey AccessKey) (projectID uint64, err error) {
94+
buf, err := base62.DecodeString(accessKey.String())
10195
if err != nil {
10296
return 0, fmt.Errorf("base62 decode: %w", err)
10397
}
@@ -113,16 +107,16 @@ type V1 struct{}
113107

114108
func (V1) Version() byte { return 1 }
115109

116-
func (v V1) Encode(_ context.Context, projectID uint64) string {
110+
func (v V1) Encode(_ context.Context, projectID uint64) AccessKey {
117111
buf := make([]byte, sizeV1)
118112
buf[0] = v.Version()
119113
binary.BigEndian.PutUint64(buf[1:], projectID)
120114
_, _ = rand.Read(buf[9:])
121-
return base64.Base64UrlEncode(buf)
115+
return AccessKey(base64.Base64UrlEncode(buf))
122116
}
123117

124-
func (V1) Decode(accessKey string) (projectID uint64, err error) {
125-
buf, err := base64.Base64UrlDecode(accessKey)
118+
func (V1) Decode(accessKey AccessKey) (projectID uint64, err error) {
119+
buf, err := base64.Base64UrlDecode(accessKey.String())
126120
if err != nil {
127121
return 0, fmt.Errorf("base64 decode: %w", err)
128122
}
@@ -143,19 +137,19 @@ const (
143137

144138
func (V2) Version() byte { return 2 }
145139

146-
func (v V2) Encode(ctx context.Context, projectID uint64) string {
140+
func (v V2) Encode(ctx context.Context, projectID uint64) AccessKey {
147141
buf := make([]byte, sizeV2)
148142
buf[0] = v.Version()
149143
binary.BigEndian.PutUint64(buf[1:], projectID)
150144
_, _ = rand.Read(buf[9:])
151-
return getPrefix(ctx) + Separator + base64.Base64UrlEncode(buf)
145+
return AccessKey(getPrefix(ctx) + Separator + base64.Base64UrlEncode(buf))
152146
}
153147

154-
func (V2) Decode(accessKey string) (projectID uint64, err error) {
155-
parts := strings.Split(accessKey, Separator)
156-
accessKey = parts[len(parts)-1]
148+
func (V2) Decode(accessKey AccessKey) (projectID uint64, err error) {
149+
parts := strings.Split(accessKey.String(), Separator)
150+
raw := parts[len(parts)-1]
157151

158-
buf, err := base64.Base64UrlDecode(accessKey)
152+
buf, err := base64.Base64UrlDecode(raw)
159153
if err != nil {
160154
return 0, fmt.Errorf("base64 decode: %w", err)
161155
}

access_key_test.go

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,50 +14,55 @@ func TestAccessKeyEncoding(t *testing.T) {
1414
t.Run("v0", func(t *testing.T) {
1515
ctx := authcontrol.WithVersion(context.Background(), 0)
1616
projectID := uint64(12345)
17-
accessKey := authcontrol.GenerateAccessKey(ctx, projectID)
17+
accessKey, err := authcontrol.GenerateAccessKey(ctx, projectID)
18+
require.NoError(t, err)
1819
t.Log("=> k", accessKey)
1920

20-
outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey)
21+
outID, err := accessKey.GetProjectID()
2122
require.NoError(t, err)
2223
require.Equal(t, projectID, outID)
2324
})
2425

2526
t.Run("v1", func(t *testing.T) {
2627
ctx := authcontrol.WithVersion(context.Background(), 1)
2728
projectID := uint64(12345)
28-
accessKey := authcontrol.GenerateAccessKey(ctx, projectID)
29+
accessKey, err := authcontrol.GenerateAccessKey(ctx, projectID)
30+
require.NoError(t, err)
2931
t.Log("=> k", accessKey)
30-
outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey)
32+
outID, err := accessKey.GetProjectID()
3133
require.NoError(t, err)
3234
require.Equal(t, projectID, outID)
3335
})
3436
t.Run("v2", func(t *testing.T) {
3537
ctx := authcontrol.WithVersion(context.Background(), 2)
3638
projectID := uint64(12345)
37-
accessKey := authcontrol.GenerateAccessKey(ctx, projectID)
38-
t.Log("=> k", accessKey, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey))
39-
outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey)
39+
accessKey, err := authcontrol.GenerateAccessKey(ctx, projectID)
40+
require.NoError(t, err)
41+
t.Log("=> k", accessKey, "| prefix =>", accessKey.GetPrefix())
42+
outID, err := accessKey.GetProjectID()
4043
require.NoError(t, err)
4144
require.Equal(t, projectID, outID)
4245

4346
ctx = authcontrol.WithPrefix(ctx, "newprefix:dev")
4447

45-
accessKey2 := authcontrol.GenerateAccessKey(ctx, projectID)
46-
t.Log("=> k", accessKey2, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey2))
47-
outID, err = authcontrol.GetProjectIDFromAccessKey(accessKey2)
48+
accessKey2, err := authcontrol.GenerateAccessKey(ctx, projectID)
49+
require.NoError(t, err)
50+
t.Log("=> k", accessKey2, "| prefix =>", accessKey2.GetPrefix())
51+
outID, err = accessKey2.GetProjectID()
4852
require.NoError(t, err)
4953
require.Equal(t, projectID, outID)
5054
// retrocompatibility with the older prefix
51-
outID, err = authcontrol.GetProjectIDFromAccessKey(accessKey)
55+
outID, err = accessKey.GetProjectID()
5256
require.NoError(t, err)
5357
require.Equal(t, projectID, outID)
5458
})
5559
}
5660

5761
func TestDecode(t *testing.T) {
5862
ctx := authcontrol.WithVersion(context.Background(), 2)
59-
accessKey := authcontrol.GenerateAccessKey(ctx, 237)
60-
t.Log("=> k", accessKey, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey))
63+
accessKey, err := authcontrol.GenerateAccessKey(ctx, 237)
64+
require.NoError(t, err)
65+
t.Log("=> k", accessKey, "| prefix =>", accessKey.GetPrefix())
6166
}
6267

6368
func TestForwardAccessKeyTransport(t *testing.T) {
@@ -71,7 +76,7 @@ func TestForwardAccessKeyTransport(t *testing.T) {
7176

7277
// Create context with access key
7378
accessKey := "test-access-key-123"
74-
ctx := authcontrol.WithAccessKey(context.Background(), accessKey)
79+
ctx := authcontrol.WithAccessKey(context.Background(), authcontrol.AccessKey(accessKey))
7580

7681
// Create HTTP client with ForwardAccessKeyTransport
7782
client := &http.Client{

cmd/access_key/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ var decodeCmd = &cobra.Command{
2828
if len(args) != 1 {
2929
return fmt.Errorf("access key is required")
3030
}
31-
accessKey := args[0]
31+
accessKey := authcontrol.AccessKey(args[0])
3232
var (
3333
projectID uint64
3434
version byte

common.go

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414

1515
"github.com/0xsequence/authcontrol/proto"
1616
"github.com/go-chi/jwtauth/v5"
17+
"github.com/go-chi/transport"
1718
"github.com/lestrrat-go/jwx/v2/jwa"
1819
"github.com/lestrrat-go/jwx/v2/jwt"
1920
)
@@ -22,10 +23,10 @@ const (
2223
HeaderAccessKey = "X-Access-Key"
2324
)
2425

25-
type AccessKeyFunc func(*http.Request) string
26+
type AccessKeyFunc func(*http.Request) AccessKey
2627

27-
func AccessKeyFromHeader(r *http.Request) string {
28-
return r.Header.Get(HeaderAccessKey)
28+
func AccessKeyFromHeader(r *http.Request) AccessKey {
29+
return AccessKey(r.Header.Get(HeaderAccessKey))
2930
}
3031

3132
type ErrHandler func(r *http.Request, w http.ResponseWriter, err error)
@@ -198,3 +199,16 @@ func findProjectClaim(r *http.Request) (uint64, error) {
198199
return 0, fmt.Errorf("invalid type: %T", val)
199200
}
200201
}
202+
203+
// ForwardAccessKeyTransport is a RoundTripper that forwards the access key from the request context to the request header.
204+
func ForwardAccessKeyTransport(next http.RoundTripper) http.RoundTripper {
205+
return transport.RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) {
206+
r := transport.CloneRequest(req)
207+
208+
if accessKey, ok := GetAccessKey(req.Context()); ok {
209+
r.Header.Set(HeaderAccessKey, accessKey.String())
210+
}
211+
212+
return next.RoundTrip(r)
213+
})
214+
}

common_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ import (
1717

1818
const HeaderKey = "Test-Key"
1919

20-
func keyFunc(r *http.Request) string {
21-
return r.Header.Get(HeaderKey)
20+
func keyFunc(r *http.Request) authcontrol.AccessKey {
21+
return authcontrol.AccessKey(r.Header.Get(HeaderKey))
2222
}
2323

2424
type requestOption func(r *http.Request)

context.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,13 @@ func GetService(ctx context.Context) (string, bool) {
102102
// WithAccessKey adds the access key to the context.
103103
//
104104
// TODO: Deprecate this in favor of Session middleware with a JWT token.
105-
func WithAccessKey(ctx context.Context, accessKey string) context.Context {
105+
func WithAccessKey(ctx context.Context, accessKey AccessKey) context.Context {
106106
return context.WithValue(ctx, ctxKeyAccessKey, accessKey)
107107
}
108108

109109
// GetAccessKey returns the access key from the context.
110-
func GetAccessKey(ctx context.Context) (string, bool) {
111-
v, ok := ctx.Value(ctxKeyAccessKey).(string)
110+
func GetAccessKey(ctx context.Context) (AccessKey, bool) {
111+
v, ok := ctx.Value(ctxKeyAccessKey).(AccessKey)
112112
return v, ok
113113
}
114114

go.mod

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ require (
99
github.com/go-chi/metrics v0.1.0
1010
github.com/go-chi/traceid v0.2.0
1111
github.com/go-chi/transport v0.4.0
12+
github.com/goware/base64 v0.1.0
13+
github.com/jxskiss/base62 v1.1.0
1214
github.com/lestrrat-go/jwx/v2 v2.1.3
15+
github.com/spf13/cobra v1.9.1
1316
github.com/stretchr/testify v1.10.0
1417
)
1518

@@ -20,9 +23,7 @@ require (
2023
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect
2124
github.com/goccy/go-json v0.10.3 // indirect
2225
github.com/google/uuid v1.6.0 // indirect
23-
github.com/goware/base64 v0.1.0 // indirect
2426
github.com/inconshreveable/mousetrap v1.1.0 // indirect
25-
github.com/jxskiss/base62 v1.1.0 // indirect
2627
github.com/kr/text v0.2.0 // indirect
2728
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
2829
github.com/lestrrat-go/httpcc v1.0.1 // indirect
@@ -37,7 +38,6 @@ require (
3738
github.com/prometheus/procfs v0.15.1 // indirect
3839
github.com/rogpeppe/go-internal v1.12.0 // indirect
3940
github.com/segmentio/asm v1.2.0 // indirect
40-
github.com/spf13/cobra v1.9.1 // indirect
4141
github.com/spf13/pflag v1.0.6 // indirect
4242
golang.org/x/crypto v0.31.0 // indirect
4343
golang.org/x/sync v0.10.0 // indirect

go.work.sum

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtX
1515
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
1616
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
1717
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
18+
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
1819
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
1920
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
2021
golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=

middleware.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ func Session(cfg Options) func(next http.Handler) http.Handler {
251251

252252
ctx = WithAccessKey(ctx, accessKey)
253253

254-
projectID, _ = GetProjectIDFromAccessKey(accessKey)
254+
projectID, _ = accessKey.GetProjectID()
255255
ctx = withProjectID(ctx, projectID)
256256
httplog.SetAttrs(ctx, slog.Uint64("projectId", projectID))
257257
break
@@ -332,7 +332,7 @@ func PropagateAccessKey(headerContextFuncs ...func(context.Context, http.Header)
332332

333333
if accessKey, ok := GetAccessKey(ctx); ok {
334334
h := http.Header{
335-
HeaderAccessKey: []string{accessKey},
335+
HeaderAccessKey: []string{accessKey.String()},
336336
}
337337
for _, fn := range headerContextFuncs {
338338
ctx, _ = fn(ctx, h)

middleware_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,7 @@ func TestCustomErrHandler(t *testing.T) {
325325

326326
r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
327327

328-
var claims map[string]any
329-
claims = map[string]any{"service": "client_service"}
328+
claims := map[string]any{"service": "client_service"}
330329

331330
// Valid Request
332331
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))

0 commit comments

Comments
 (0)