diff --git a/access_key.go b/access_key.go index 99a23ec..15cd072 100644 --- a/access_key.go +++ b/access_key.go @@ -4,13 +4,11 @@ import ( "context" "errors" "fmt" - "net/http" "strings" "crypto/rand" "encoding/binary" - "github.com/go-chi/transport" "github.com/goware/base64" "github.com/jxskiss/base62" ) @@ -24,10 +22,16 @@ var ( ErrInvalidKeyLength = errors.New("invalid access key length") ) -func GetProjectIDFromAccessKey(accessKey string) (projectID uint64, err error) { +type AccessKey string + +func (a AccessKey) String() string { + return string(a) +} + +func (a AccessKey) GetProjectID() (projectID uint64, err error) { var errs []error for _, e := range SupportedEncodings { - projectID, err := e.Decode(accessKey) + projectID, err := e.Decode(a) if err != nil { errs = append(errs, fmt.Errorf("decode v%d: %w", e.Version(), err)) continue @@ -37,44 +41,34 @@ func GetProjectIDFromAccessKey(accessKey string) (projectID uint64, err error) { return 0, errors.Join(errs...) } -func GenerateAccessKey(ctx context.Context, projectID uint64) string { - version, ok := GetVersion(ctx) - if !ok { - return DefaultEncoding.Encode(ctx, projectID) - } - - for _, e := range SupportedEncodings { - if e.Version() == version { - return e.Encode(ctx, projectID) - } - } - return "" -} - -func GetAccessKeyPrefix(accessKey string) string { - parts := strings.Split(accessKey, Separator) +func (a AccessKey) GetPrefix() string { + parts := strings.Split(a.String(), Separator) if len(parts) < 2 { return "" } return strings.Join(parts[:len(parts)-1], Separator) } -func ForwardAccessKeyTransport(next http.RoundTripper) http.RoundTripper { - return transport.RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) { - r := transport.CloneRequest(req) +var ErrUnsupportedEncoding = errors.New("unsupported access key encoding") - if accessKey, ok := GetAccessKey(req.Context()); ok { - r.Header.Set(HeaderAccessKey, accessKey) - } +func GenerateAccessKey(ctx context.Context, projectID uint64) (AccessKey, error) { + version, ok := GetVersion(ctx) + if !ok { + return DefaultEncoding.Encode(ctx, projectID), nil + } - return next.RoundTrip(r) - }) + for _, e := range SupportedEncodings { + if e.Version() == version { + return e.Encode(ctx, projectID), nil + } + } + return "", ErrUnsupportedEncoding } type Encoding interface { Version() byte - Encode(ctx context.Context, projectID uint64) string - Decode(accessKey string) (projectID uint64, err error) + Encode(ctx context.Context, projectID uint64) AccessKey + Decode(accessKey AccessKey) (projectID uint64, err error) } const ( @@ -89,15 +83,15 @@ type V0 struct{} func (V0) Version() byte { return 0 } -func (V0) Encode(_ context.Context, projectID uint64) string { +func (V0) Encode(_ context.Context, projectID uint64) AccessKey { buf := make([]byte, sizeV0) binary.BigEndian.PutUint64(buf, projectID) _, _ = rand.Read(buf[8:]) - return base62.EncodeToString(buf) + return AccessKey(base62.EncodeToString(buf)) } -func (V0) Decode(accessKey string) (projectID uint64, err error) { - buf, err := base62.DecodeString(accessKey) +func (V0) Decode(accessKey AccessKey) (projectID uint64, err error) { + buf, err := base62.DecodeString(accessKey.String()) if err != nil { return 0, fmt.Errorf("base62 decode: %w", err) } @@ -113,16 +107,16 @@ type V1 struct{} func (V1) Version() byte { return 1 } -func (v V1) Encode(_ context.Context, projectID uint64) string { +func (v V1) Encode(_ context.Context, projectID uint64) AccessKey { buf := make([]byte, sizeV1) buf[0] = v.Version() binary.BigEndian.PutUint64(buf[1:], projectID) _, _ = rand.Read(buf[9:]) - return base64.Base64UrlEncode(buf) + return AccessKey(base64.Base64UrlEncode(buf)) } -func (V1) Decode(accessKey string) (projectID uint64, err error) { - buf, err := base64.Base64UrlDecode(accessKey) +func (V1) Decode(accessKey AccessKey) (projectID uint64, err error) { + buf, err := base64.Base64UrlDecode(accessKey.String()) if err != nil { return 0, fmt.Errorf("base64 decode: %w", err) } @@ -143,19 +137,19 @@ const ( func (V2) Version() byte { return 2 } -func (v V2) Encode(ctx context.Context, projectID uint64) string { +func (v V2) Encode(ctx context.Context, projectID uint64) AccessKey { buf := make([]byte, sizeV2) buf[0] = v.Version() binary.BigEndian.PutUint64(buf[1:], projectID) _, _ = rand.Read(buf[9:]) - return getPrefix(ctx) + Separator + base64.Base64UrlEncode(buf) + return AccessKey(getPrefix(ctx) + Separator + base64.Base64UrlEncode(buf)) } -func (V2) Decode(accessKey string) (projectID uint64, err error) { - parts := strings.Split(accessKey, Separator) - accessKey = parts[len(parts)-1] +func (V2) Decode(accessKey AccessKey) (projectID uint64, err error) { + parts := strings.Split(accessKey.String(), Separator) + raw := parts[len(parts)-1] - buf, err := base64.Base64UrlDecode(accessKey) + buf, err := base64.Base64UrlDecode(raw) if err != nil { return 0, fmt.Errorf("base64 decode: %w", err) } diff --git a/access_key_test.go b/access_key_test.go index 4f98249..2c27497 100644 --- a/access_key_test.go +++ b/access_key_test.go @@ -14,10 +14,11 @@ func TestAccessKeyEncoding(t *testing.T) { t.Run("v0", func(t *testing.T) { ctx := authcontrol.WithVersion(context.Background(), 0) projectID := uint64(12345) - accessKey := authcontrol.GenerateAccessKey(ctx, projectID) + accessKey, err := authcontrol.GenerateAccessKey(ctx, projectID) + require.NoError(t, err) t.Log("=> k", accessKey) - outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey) + outID, err := accessKey.GetProjectID() require.NoError(t, err) require.Equal(t, projectID, outID) }) @@ -25,30 +26,33 @@ func TestAccessKeyEncoding(t *testing.T) { t.Run("v1", func(t *testing.T) { ctx := authcontrol.WithVersion(context.Background(), 1) projectID := uint64(12345) - accessKey := authcontrol.GenerateAccessKey(ctx, projectID) + accessKey, err := authcontrol.GenerateAccessKey(ctx, projectID) + require.NoError(t, err) t.Log("=> k", accessKey) - outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey) + outID, err := accessKey.GetProjectID() require.NoError(t, err) require.Equal(t, projectID, outID) }) t.Run("v2", func(t *testing.T) { ctx := authcontrol.WithVersion(context.Background(), 2) projectID := uint64(12345) - accessKey := authcontrol.GenerateAccessKey(ctx, projectID) - t.Log("=> k", accessKey, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey)) - outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey) + accessKey, err := authcontrol.GenerateAccessKey(ctx, projectID) + require.NoError(t, err) + t.Log("=> k", accessKey, "| prefix =>", accessKey.GetPrefix()) + outID, err := accessKey.GetProjectID() require.NoError(t, err) require.Equal(t, projectID, outID) ctx = authcontrol.WithPrefix(ctx, "newprefix:dev") - accessKey2 := authcontrol.GenerateAccessKey(ctx, projectID) - t.Log("=> k", accessKey2, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey2)) - outID, err = authcontrol.GetProjectIDFromAccessKey(accessKey2) + accessKey2, err := authcontrol.GenerateAccessKey(ctx, projectID) + require.NoError(t, err) + t.Log("=> k", accessKey2, "| prefix =>", accessKey2.GetPrefix()) + outID, err = accessKey2.GetProjectID() require.NoError(t, err) require.Equal(t, projectID, outID) // retrocompatibility with the older prefix - outID, err = authcontrol.GetProjectIDFromAccessKey(accessKey) + outID, err = accessKey.GetProjectID() require.NoError(t, err) require.Equal(t, projectID, outID) }) @@ -56,8 +60,9 @@ func TestAccessKeyEncoding(t *testing.T) { func TestDecode(t *testing.T) { ctx := authcontrol.WithVersion(context.Background(), 2) - accessKey := authcontrol.GenerateAccessKey(ctx, 237) - t.Log("=> k", accessKey, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey)) + accessKey, err := authcontrol.GenerateAccessKey(ctx, 237) + require.NoError(t, err) + t.Log("=> k", accessKey, "| prefix =>", accessKey.GetPrefix()) } func TestForwardAccessKeyTransport(t *testing.T) { @@ -71,7 +76,7 @@ func TestForwardAccessKeyTransport(t *testing.T) { // Create context with access key accessKey := "test-access-key-123" - ctx := authcontrol.WithAccessKey(context.Background(), accessKey) + ctx := authcontrol.WithAccessKey(context.Background(), authcontrol.AccessKey(accessKey)) // Create HTTP client with ForwardAccessKeyTransport client := &http.Client{ diff --git a/cmd/access_key/main.go b/cmd/access_key/main.go index 152b1d5..9b1b125 100644 --- a/cmd/access_key/main.go +++ b/cmd/access_key/main.go @@ -28,7 +28,7 @@ var decodeCmd = &cobra.Command{ if len(args) != 1 { return fmt.Errorf("access key is required") } - accessKey := args[0] + accessKey := authcontrol.AccessKey(args[0]) var ( projectID uint64 version byte diff --git a/common.go b/common.go index 1dc3b8a..8d4cebe 100644 --- a/common.go +++ b/common.go @@ -14,6 +14,7 @@ import ( "github.com/0xsequence/authcontrol/proto" "github.com/go-chi/jwtauth/v5" + "github.com/go-chi/transport" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" ) @@ -22,10 +23,10 @@ const ( HeaderAccessKey = "X-Access-Key" ) -type AccessKeyFunc func(*http.Request) string +type AccessKeyFunc func(*http.Request) AccessKey -func AccessKeyFromHeader(r *http.Request) string { - return r.Header.Get(HeaderAccessKey) +func AccessKeyFromHeader(r *http.Request) AccessKey { + return AccessKey(r.Header.Get(HeaderAccessKey)) } type ErrHandler func(r *http.Request, w http.ResponseWriter, err error) @@ -198,3 +199,16 @@ func findProjectClaim(r *http.Request) (uint64, error) { return 0, fmt.Errorf("invalid type: %T", val) } } + +// ForwardAccessKeyTransport is a RoundTripper that forwards the access key from the request context to the request header. +func ForwardAccessKeyTransport(next http.RoundTripper) http.RoundTripper { + return transport.RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) { + r := transport.CloneRequest(req) + + if accessKey, ok := GetAccessKey(req.Context()); ok { + r.Header.Set(HeaderAccessKey, accessKey.String()) + } + + return next.RoundTrip(r) + }) +} diff --git a/common_test.go b/common_test.go index 17a1e19..b456779 100644 --- a/common_test.go +++ b/common_test.go @@ -17,8 +17,8 @@ import ( const HeaderKey = "Test-Key" -func keyFunc(r *http.Request) string { - return r.Header.Get(HeaderKey) +func keyFunc(r *http.Request) authcontrol.AccessKey { + return authcontrol.AccessKey(r.Header.Get(HeaderKey)) } type requestOption func(r *http.Request) diff --git a/context.go b/context.go index b16a1e2..da76542 100644 --- a/context.go +++ b/context.go @@ -102,13 +102,13 @@ func GetService(ctx context.Context) (string, bool) { // WithAccessKey adds the access key to the context. // // TODO: Deprecate this in favor of Session middleware with a JWT token. -func WithAccessKey(ctx context.Context, accessKey string) context.Context { +func WithAccessKey(ctx context.Context, accessKey AccessKey) context.Context { return context.WithValue(ctx, ctxKeyAccessKey, accessKey) } // GetAccessKey returns the access key from the context. -func GetAccessKey(ctx context.Context) (string, bool) { - v, ok := ctx.Value(ctxKeyAccessKey).(string) +func GetAccessKey(ctx context.Context) (AccessKey, bool) { + v, ok := ctx.Value(ctxKeyAccessKey).(AccessKey) return v, ok } diff --git a/go.mod b/go.mod index 5195ab9..36c8dc2 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,10 @@ require ( github.com/go-chi/metrics v0.1.0 github.com/go-chi/traceid v0.2.0 github.com/go-chi/transport v0.4.0 + github.com/goware/base64 v0.1.0 + github.com/jxskiss/base62 v1.1.0 github.com/lestrrat-go/jwx/v2 v2.1.3 + github.com/spf13/cobra v1.9.1 github.com/stretchr/testify v1.10.0 ) @@ -20,9 +23,7 @@ require ( github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect github.com/goccy/go-json v0.10.3 // indirect github.com/google/uuid v1.6.0 // indirect - github.com/goware/base64 v0.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/jxskiss/base62 v1.1.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/lestrrat-go/blackmagic v1.0.2 // indirect github.com/lestrrat-go/httpcc v1.0.1 // indirect @@ -37,7 +38,6 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/segmentio/asm v1.2.0 // indirect - github.com/spf13/cobra v1.9.1 // indirect github.com/spf13/pflag v1.0.6 // indirect golang.org/x/crypto v0.31.0 // indirect golang.org/x/sync v0.10.0 // indirect diff --git a/go.work.sum b/go.work.sum index 4c8099d..01f0bef 100644 --- a/go.work.sum +++ b/go.work.sum @@ -15,6 +15,7 @@ github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtX golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/middleware.go b/middleware.go index 338a3f6..fbfebe0 100644 --- a/middleware.go +++ b/middleware.go @@ -251,7 +251,7 @@ func Session(cfg Options) func(next http.Handler) http.Handler { ctx = WithAccessKey(ctx, accessKey) - projectID, _ = GetProjectIDFromAccessKey(accessKey) + projectID, _ = accessKey.GetProjectID() ctx = withProjectID(ctx, projectID) httplog.SetAttrs(ctx, slog.Uint64("projectId", projectID)) break @@ -332,7 +332,7 @@ func PropagateAccessKey(headerContextFuncs ...func(context.Context, http.Header) if accessKey, ok := GetAccessKey(ctx); ok { h := http.Header{ - HeaderAccessKey: []string{accessKey}, + HeaderAccessKey: []string{accessKey.String()}, } for _, fn := range headerContextFuncs { ctx, _ = fn(ctx, h) diff --git a/middleware_test.go b/middleware_test.go index 147a23c..2e73571 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -325,8 +325,7 @@ func TestCustomErrHandler(t *testing.T) { r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - var claims map[string]any - claims = map[string]any{"service": "client_service"} + claims := map[string]any{"service": "client_service"} // Valid Request ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))