Skip to content

Commit 7dc2de7

Browse files
committed
Add GetPrefix method to AccessKey and update tests for consistency
1 parent f1231e3 commit 7dc2de7

File tree

4 files changed

+26
-27
lines changed

4 files changed

+26
-27
lines changed

access_key.go

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@ import (
44
"context"
55
"errors"
66
"fmt"
7-
"net/http"
87
"strings"
98

109
"crypto/rand"
1110
"encoding/binary"
1211

13-
"github.com/go-chi/transport"
1412
"github.com/goware/base64"
1513
"github.com/jxskiss/base62"
1614
)
@@ -43,6 +41,14 @@ func (a AccessKey) GetProjectID() (projectID uint64, err error) {
4341
return 0, errors.Join(errs...)
4442
}
4543

44+
func (a AccessKey) GetPrefix() string {
45+
parts := strings.Split(a.String(), Separator)
46+
if len(parts) < 2 {
47+
return ""
48+
}
49+
return strings.Join(parts[:len(parts)-1], Separator)
50+
}
51+
4652
func NewAccessKey(ctx context.Context, projectID uint64) AccessKey {
4753
version, ok := GetVersion(ctx)
4854
if !ok {
@@ -57,26 +63,6 @@ func NewAccessKey(ctx context.Context, projectID uint64) AccessKey {
5763
return ""
5864
}
5965

60-
func GetAccessKeyPrefix(accessKey AccessKey) string {
61-
parts := strings.Split(accessKey.String(), Separator)
62-
if len(parts) < 2 {
63-
return ""
64-
}
65-
return strings.Join(parts[:len(parts)-1], Separator)
66-
}
67-
68-
func ForwardAccessKeyTransport(next http.RoundTripper) http.RoundTripper {
69-
return transport.RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) {
70-
r := transport.CloneRequest(req)
71-
72-
if accessKey, ok := GetAccessKey(req.Context()); ok {
73-
r.Header.Set(HeaderAccessKey, accessKey)
74-
}
75-
76-
return next.RoundTrip(r)
77-
})
78-
}
79-
8066
type Encoding interface {
8167
Version() byte
8268
Encode(ctx context.Context, projectID uint64) AccessKey

access_key_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ func TestAccessKeyEncoding(t *testing.T) {
3535
ctx := authcontrol.WithVersion(context.Background(), 2)
3636
projectID := uint64(12345)
3737
accessKey := authcontrol.NewAccessKey(ctx, projectID)
38-
t.Log("=> k", accessKey, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey))
38+
t.Log("=> k", accessKey, "| prefix =>", accessKey.GetPrefix())
3939
outID, err := accessKey.GetProjectID()
4040
require.NoError(t, err)
4141
require.Equal(t, projectID, outID)
4242

4343
ctx = authcontrol.WithPrefix(ctx, "newprefix:dev")
4444

4545
accessKey2 := authcontrol.NewAccessKey(ctx, projectID)
46-
t.Log("=> k", accessKey2, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey2))
46+
t.Log("=> k", accessKey2, "| prefix =>", accessKey2.GetPrefix())
4747
outID, err = accessKey2.GetProjectID()
4848
require.NoError(t, err)
4949
require.Equal(t, projectID, outID)
@@ -57,7 +57,7 @@ func TestAccessKeyEncoding(t *testing.T) {
5757
func TestDecode(t *testing.T) {
5858
ctx := authcontrol.WithVersion(context.Background(), 2)
5959
accessKey := authcontrol.NewAccessKey(ctx, 237)
60-
t.Log("=> k", accessKey, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey))
60+
t.Log("=> k", accessKey, "| prefix =>", accessKey.GetPrefix())
6161
}
6262

6363
func TestForwardAccessKeyTransport(t *testing.T) {

common.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414

1515
"github.com/0xsequence/authcontrol/proto"
1616
"github.com/go-chi/jwtauth/v5"
17+
"github.com/go-chi/transport"
1718
"github.com/lestrrat-go/jwx/v2/jwa"
1819
"github.com/lestrrat-go/jwx/v2/jwt"
1920
)
@@ -198,3 +199,16 @@ func findProjectClaim(r *http.Request) (uint64, error) {
198199
return 0, fmt.Errorf("invalid type: %T", val)
199200
}
200201
}
202+
203+
// ForwardAccessKeyTransport is a RoundTripper that forwards the access key from the request context to the request header.
204+
func ForwardAccessKeyTransport(next http.RoundTripper) http.RoundTripper {
205+
return transport.RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) {
206+
r := transport.CloneRequest(req)
207+
208+
if accessKey, ok := GetAccessKey(req.Context()); ok {
209+
r.Header.Set(HeaderAccessKey, accessKey.String())
210+
}
211+
212+
return next.RoundTrip(r)
213+
})
214+
}

middleware_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,7 @@ func TestCustomErrHandler(t *testing.T) {
325325

326326
r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
327327

328-
var claims map[string]any
329-
claims = map[string]any{"service": "client_service"}
328+
claims := map[string]any{"service": "client_service"}
330329

331330
// Valid Request
332331
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))

0 commit comments

Comments
 (0)