Skip to content

Commit 42e426f

Browse files
authored
Revert "Introduce AccessKey type for improved type safety (#46)" (#49)
This reverts commit 1a44627.
1 parent aca750a commit 42e426f

File tree

10 files changed

+74
-87
lines changed

10 files changed

+74
-87
lines changed

access_key.go

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"net/http"
78
"strings"
89

910
"crypto/rand"
1011
"encoding/binary"
1112

13+
"github.com/go-chi/transport"
1214
"github.com/goware/base64"
1315
"github.com/jxskiss/base62"
1416
)
@@ -22,16 +24,10 @@ var (
2224
ErrInvalidKeyLength = errors.New("invalid access key length")
2325
)
2426

25-
type AccessKey string
26-
27-
func (a AccessKey) String() string {
28-
return string(a)
29-
}
30-
31-
func (a AccessKey) GetProjectID() (projectID uint64, err error) {
27+
func GetProjectIDFromAccessKey(accessKey string) (projectID uint64, err error) {
3228
var errs []error
3329
for _, e := range SupportedEncodings {
34-
projectID, err := e.Decode(a)
30+
projectID, err := e.Decode(accessKey)
3531
if err != nil {
3632
errs = append(errs, fmt.Errorf("decode v%d: %w", e.Version(), err))
3733
continue
@@ -41,34 +37,44 @@ func (a AccessKey) GetProjectID() (projectID uint64, err error) {
4137
return 0, errors.Join(errs...)
4238
}
4339

44-
func (a AccessKey) GetPrefix() string {
45-
parts := strings.Split(a.String(), Separator)
46-
if len(parts) < 2 {
47-
return ""
48-
}
49-
return strings.Join(parts[:len(parts)-1], Separator)
50-
}
51-
52-
var ErrUnsupportedEncoding = errors.New("unsupported access key encoding")
53-
54-
func GenerateAccessKey(ctx context.Context, projectID uint64) (AccessKey, error) {
40+
func GenerateAccessKey(ctx context.Context, projectID uint64) string {
5541
version, ok := GetVersion(ctx)
5642
if !ok {
57-
return DefaultEncoding.Encode(ctx, projectID), nil
43+
return DefaultEncoding.Encode(ctx, projectID)
5844
}
5945

6046
for _, e := range SupportedEncodings {
6147
if e.Version() == version {
62-
return e.Encode(ctx, projectID), nil
48+
return e.Encode(ctx, projectID)
6349
}
6450
}
65-
return "", ErrUnsupportedEncoding
51+
return ""
52+
}
53+
54+
func GetAccessKeyPrefix(accessKey string) string {
55+
parts := strings.Split(accessKey, Separator)
56+
if len(parts) < 2 {
57+
return ""
58+
}
59+
return strings.Join(parts[:len(parts)-1], Separator)
60+
}
61+
62+
func ForwardAccessKeyTransport(next http.RoundTripper) http.RoundTripper {
63+
return transport.RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) {
64+
r := transport.CloneRequest(req)
65+
66+
if accessKey, ok := GetAccessKey(req.Context()); ok {
67+
r.Header.Set(HeaderAccessKey, accessKey)
68+
}
69+
70+
return next.RoundTrip(r)
71+
})
6672
}
6773

6874
type Encoding interface {
6975
Version() byte
70-
Encode(ctx context.Context, projectID uint64) AccessKey
71-
Decode(accessKey AccessKey) (projectID uint64, err error)
76+
Encode(ctx context.Context, projectID uint64) string
77+
Decode(accessKey string) (projectID uint64, err error)
7278
}
7379

7480
const (
@@ -83,15 +89,15 @@ type V0 struct{}
8389

8490
func (V0) Version() byte { return 0 }
8591

86-
func (V0) Encode(_ context.Context, projectID uint64) AccessKey {
92+
func (V0) Encode(_ context.Context, projectID uint64) string {
8793
buf := make([]byte, sizeV0)
8894
binary.BigEndian.PutUint64(buf, projectID)
8995
_, _ = rand.Read(buf[8:])
90-
return AccessKey(base62.EncodeToString(buf))
96+
return base62.EncodeToString(buf)
9197
}
9298

93-
func (V0) Decode(accessKey AccessKey) (projectID uint64, err error) {
94-
buf, err := base62.DecodeString(accessKey.String())
99+
func (V0) Decode(accessKey string) (projectID uint64, err error) {
100+
buf, err := base62.DecodeString(accessKey)
95101
if err != nil {
96102
return 0, fmt.Errorf("base62 decode: %w", err)
97103
}
@@ -107,16 +113,16 @@ type V1 struct{}
107113

108114
func (V1) Version() byte { return 1 }
109115

110-
func (v V1) Encode(_ context.Context, projectID uint64) AccessKey {
116+
func (v V1) Encode(_ context.Context, projectID uint64) string {
111117
buf := make([]byte, sizeV1)
112118
buf[0] = v.Version()
113119
binary.BigEndian.PutUint64(buf[1:], projectID)
114120
_, _ = rand.Read(buf[9:])
115-
return AccessKey(base64.Base64UrlEncode(buf))
121+
return base64.Base64UrlEncode(buf)
116122
}
117123

118-
func (V1) Decode(accessKey AccessKey) (projectID uint64, err error) {
119-
buf, err := base64.Base64UrlDecode(accessKey.String())
124+
func (V1) Decode(accessKey string) (projectID uint64, err error) {
125+
buf, err := base64.Base64UrlDecode(accessKey)
120126
if err != nil {
121127
return 0, fmt.Errorf("base64 decode: %w", err)
122128
}
@@ -137,19 +143,19 @@ const (
137143

138144
func (V2) Version() byte { return 2 }
139145

140-
func (v V2) Encode(ctx context.Context, projectID uint64) AccessKey {
146+
func (v V2) Encode(ctx context.Context, projectID uint64) string {
141147
buf := make([]byte, sizeV2)
142148
buf[0] = v.Version()
143149
binary.BigEndian.PutUint64(buf[1:], projectID)
144150
_, _ = rand.Read(buf[9:])
145-
return AccessKey(getPrefix(ctx) + Separator + base64.Base64UrlEncode(buf))
151+
return getPrefix(ctx) + Separator + base64.Base64UrlEncode(buf)
146152
}
147153

148-
func (V2) Decode(accessKey AccessKey) (projectID uint64, err error) {
149-
parts := strings.Split(accessKey.String(), Separator)
150-
raw := parts[len(parts)-1]
154+
func (V2) Decode(accessKey string) (projectID uint64, err error) {
155+
parts := strings.Split(accessKey, Separator)
156+
accessKey = parts[len(parts)-1]
151157

152-
buf, err := base64.Base64UrlDecode(raw)
158+
buf, err := base64.Base64UrlDecode(accessKey)
153159
if err != nil {
154160
return 0, fmt.Errorf("base64 decode: %w", err)
155161
}

access_key_test.go

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,55 +14,50 @@ func TestAccessKeyEncoding(t *testing.T) {
1414
t.Run("v0", func(t *testing.T) {
1515
ctx := authcontrol.WithVersion(context.Background(), 0)
1616
projectID := uint64(12345)
17-
accessKey, err := authcontrol.GenerateAccessKey(ctx, projectID)
18-
require.NoError(t, err)
17+
accessKey := authcontrol.GenerateAccessKey(ctx, projectID)
1918
t.Log("=> k", accessKey)
2019

21-
outID, err := accessKey.GetProjectID()
20+
outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey)
2221
require.NoError(t, err)
2322
require.Equal(t, projectID, outID)
2423
})
2524

2625
t.Run("v1", func(t *testing.T) {
2726
ctx := authcontrol.WithVersion(context.Background(), 1)
2827
projectID := uint64(12345)
29-
accessKey, err := authcontrol.GenerateAccessKey(ctx, projectID)
30-
require.NoError(t, err)
28+
accessKey := authcontrol.GenerateAccessKey(ctx, projectID)
3129
t.Log("=> k", accessKey)
32-
outID, err := accessKey.GetProjectID()
30+
outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey)
3331
require.NoError(t, err)
3432
require.Equal(t, projectID, outID)
3533
})
3634
t.Run("v2", func(t *testing.T) {
3735
ctx := authcontrol.WithVersion(context.Background(), 2)
3836
projectID := uint64(12345)
39-
accessKey, err := authcontrol.GenerateAccessKey(ctx, projectID)
40-
require.NoError(t, err)
41-
t.Log("=> k", accessKey, "| prefix =>", accessKey.GetPrefix())
42-
outID, err := accessKey.GetProjectID()
37+
accessKey := authcontrol.GenerateAccessKey(ctx, projectID)
38+
t.Log("=> k", accessKey, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey))
39+
outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey)
4340
require.NoError(t, err)
4441
require.Equal(t, projectID, outID)
4542

4643
ctx = authcontrol.WithPrefix(ctx, "newprefix:dev")
4744

48-
accessKey2, err := authcontrol.GenerateAccessKey(ctx, projectID)
49-
require.NoError(t, err)
50-
t.Log("=> k", accessKey2, "| prefix =>", accessKey2.GetPrefix())
51-
outID, err = accessKey2.GetProjectID()
45+
accessKey2 := authcontrol.GenerateAccessKey(ctx, projectID)
46+
t.Log("=> k", accessKey2, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey2))
47+
outID, err = authcontrol.GetProjectIDFromAccessKey(accessKey2)
5248
require.NoError(t, err)
5349
require.Equal(t, projectID, outID)
5450
// retrocompatibility with the older prefix
55-
outID, err = accessKey.GetProjectID()
51+
outID, err = authcontrol.GetProjectIDFromAccessKey(accessKey)
5652
require.NoError(t, err)
5753
require.Equal(t, projectID, outID)
5854
})
5955
}
6056

6157
func TestDecode(t *testing.T) {
6258
ctx := authcontrol.WithVersion(context.Background(), 2)
63-
accessKey, err := authcontrol.GenerateAccessKey(ctx, 237)
64-
require.NoError(t, err)
65-
t.Log("=> k", accessKey, "| prefix =>", accessKey.GetPrefix())
59+
accessKey := authcontrol.GenerateAccessKey(ctx, 237)
60+
t.Log("=> k", accessKey, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey))
6661
}
6762

6863
func TestForwardAccessKeyTransport(t *testing.T) {
@@ -76,7 +71,7 @@ func TestForwardAccessKeyTransport(t *testing.T) {
7671

7772
// Create context with access key
7873
accessKey := "test-access-key-123"
79-
ctx := authcontrol.WithAccessKey(context.Background(), authcontrol.AccessKey(accessKey))
74+
ctx := authcontrol.WithAccessKey(context.Background(), accessKey)
8075

8176
// Create HTTP client with ForwardAccessKeyTransport
8277
client := &http.Client{

cmd/access_key/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ var decodeCmd = &cobra.Command{
2828
if len(args) != 1 {
2929
return fmt.Errorf("access key is required")
3030
}
31-
accessKey := authcontrol.AccessKey(args[0])
31+
accessKey := args[0]
3232
var (
3333
projectID uint64
3434
version byte

common.go

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414

1515
"github.com/0xsequence/authcontrol/proto"
1616
"github.com/go-chi/jwtauth/v5"
17-
"github.com/go-chi/transport"
1817
"github.com/lestrrat-go/jwx/v2/jwa"
1918
"github.com/lestrrat-go/jwx/v2/jwt"
2019
)
@@ -23,10 +22,10 @@ const (
2322
HeaderAccessKey = "X-Access-Key"
2423
)
2524

26-
type AccessKeyFunc func(*http.Request) AccessKey
25+
type AccessKeyFunc func(*http.Request) string
2726

28-
func AccessKeyFromHeader(r *http.Request) AccessKey {
29-
return AccessKey(r.Header.Get(HeaderAccessKey))
27+
func AccessKeyFromHeader(r *http.Request) string {
28+
return r.Header.Get(HeaderAccessKey)
3029
}
3130

3231
type ErrHandler func(r *http.Request, w http.ResponseWriter, err error)
@@ -199,16 +198,3 @@ func findProjectClaim(r *http.Request) (uint64, error) {
199198
return 0, fmt.Errorf("invalid type: %T", val)
200199
}
201200
}
202-
203-
// ForwardAccessKeyTransport is a RoundTripper that forwards the access key from the request context to the request header.
204-
func ForwardAccessKeyTransport(next http.RoundTripper) http.RoundTripper {
205-
return transport.RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) {
206-
r := transport.CloneRequest(req)
207-
208-
if accessKey, ok := GetAccessKey(req.Context()); ok {
209-
r.Header.Set(HeaderAccessKey, accessKey.String())
210-
}
211-
212-
return next.RoundTrip(r)
213-
})
214-
}

common_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ import (
1717

1818
const HeaderKey = "Test-Key"
1919

20-
func keyFunc(r *http.Request) authcontrol.AccessKey {
21-
return authcontrol.AccessKey(r.Header.Get(HeaderKey))
20+
func keyFunc(r *http.Request) string {
21+
return r.Header.Get(HeaderKey)
2222
}
2323

2424
type requestOption func(r *http.Request)

context.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,13 @@ func GetService(ctx context.Context) (string, bool) {
102102
// WithAccessKey adds the access key to the context.
103103
//
104104
// TODO: Deprecate this in favor of Session middleware with a JWT token.
105-
func WithAccessKey(ctx context.Context, accessKey AccessKey) context.Context {
105+
func WithAccessKey(ctx context.Context, accessKey string) context.Context {
106106
return context.WithValue(ctx, ctxKeyAccessKey, accessKey)
107107
}
108108

109109
// GetAccessKey returns the access key from the context.
110-
func GetAccessKey(ctx context.Context) (AccessKey, bool) {
111-
v, ok := ctx.Value(ctxKeyAccessKey).(AccessKey)
110+
func GetAccessKey(ctx context.Context) (string, bool) {
111+
v, ok := ctx.Value(ctxKeyAccessKey).(string)
112112
return v, ok
113113
}
114114

go.mod

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@ require (
99
github.com/go-chi/metrics v0.1.0
1010
github.com/go-chi/traceid v0.2.0
1111
github.com/go-chi/transport v0.4.0
12-
github.com/goware/base64 v0.1.0
13-
github.com/jxskiss/base62 v1.1.0
1412
github.com/lestrrat-go/jwx/v2 v2.1.3
15-
github.com/spf13/cobra v1.9.1
1613
github.com/stretchr/testify v1.10.0
1714
)
1815

@@ -23,7 +20,9 @@ require (
2320
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect
2421
github.com/goccy/go-json v0.10.3 // indirect
2522
github.com/google/uuid v1.6.0 // indirect
23+
github.com/goware/base64 v0.1.0 // indirect
2624
github.com/inconshreveable/mousetrap v1.1.0 // indirect
25+
github.com/jxskiss/base62 v1.1.0 // indirect
2726
github.com/kr/text v0.2.0 // indirect
2827
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
2928
github.com/lestrrat-go/httpcc v1.0.1 // indirect
@@ -38,6 +37,7 @@ require (
3837
github.com/prometheus/procfs v0.15.1 // indirect
3938
github.com/rogpeppe/go-internal v1.12.0 // indirect
4039
github.com/segmentio/asm v1.2.0 // indirect
40+
github.com/spf13/cobra v1.9.1 // indirect
4141
github.com/spf13/pflag v1.0.6 // indirect
4242
golang.org/x/crypto v0.31.0 // indirect
4343
golang.org/x/sync v0.10.0 // indirect

go.work.sum

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtX
1515
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
1616
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
1717
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
18-
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
1918
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
2019
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
2120
golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=

middleware.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ func Session(cfg Options) func(next http.Handler) http.Handler {
251251

252252
ctx = WithAccessKey(ctx, accessKey)
253253

254-
projectID, _ = accessKey.GetProjectID()
254+
projectID, _ = GetProjectIDFromAccessKey(accessKey)
255255
ctx = withProjectID(ctx, projectID)
256256
httplog.SetAttrs(ctx, slog.Uint64("projectId", projectID))
257257
break
@@ -332,7 +332,7 @@ func PropagateAccessKey(headerContextFuncs ...func(context.Context, http.Header)
332332

333333
if accessKey, ok := GetAccessKey(ctx); ok {
334334
h := http.Header{
335-
HeaderAccessKey: []string{accessKey.String()},
335+
HeaderAccessKey: []string{accessKey},
336336
}
337337
for _, fn := range headerContextFuncs {
338338
ctx, _ = fn(ctx, h)

middleware_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,8 @@ func TestCustomErrHandler(t *testing.T) {
325325

326326
r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
327327

328-
claims := map[string]any{"service": "client_service"}
328+
var claims map[string]any
329+
claims = map[string]any{"service": "client_service"}
329330

330331
// Valid Request
331332
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))

0 commit comments

Comments
 (0)