diff --git a/access_key.go b/access_key.go new file mode 100644 index 0000000..33cf9bd --- /dev/null +++ b/access_key.go @@ -0,0 +1,152 @@ +package authcontrol + +import ( + "context" + "errors" + "fmt" + "strings" + + "crypto/rand" + "encoding/binary" + + "github.com/goware/base64" + "github.com/jxskiss/base62" +) + +var ( + // SupportedEncodings is a list of supported encodings. If more versions of the same version are added, the first one will be used. + SupportedEncodings = []Encoding{V2{}, V1{}, V0{}} + + DefaultEncoding Encoding = V1{} + + ErrInvalidKeyLength = errors.New("invalid access key length") +) + +func GetProjectIDFromAccessKey(accessKey string) (projectID uint64, err error) { + var errs []error + for _, e := range SupportedEncodings { + projectID, err := e.Decode(accessKey) + if err != nil { + errs = append(errs, fmt.Errorf("decode v%d: %w", e.Version(), err)) + continue + } + return projectID, nil + } + return 0, errors.Join(errs...) +} + +func GenerateAccessKey(ctx context.Context, projectID uint64) string { + version, ok := GetVersion(ctx) + if !ok { + return DefaultEncoding.Encode(ctx, projectID) + } + + for _, e := range SupportedEncodings { + if e.Version() == version { + return e.Encode(ctx, projectID) + } + } + return "" +} + +func GetAccessKeyPrefix(accessKey string) string { + parts := strings.Split(accessKey, Separator) + if len(parts) < 2 { + return "" + } + return strings.Join(parts[:len(parts)-1], Separator) +} + +type Encoding interface { + Version() byte + Encode(ctx context.Context, projectID uint64) string + Decode(accessKey string) (projectID uint64, err error) +} + +const ( + sizeV0 = 24 + sizeV1 = 26 + sizeV2 = 32 +) + +// V0: base62 encoded, 24-byte fixed length. 8 bytes for project ID, rest random. +// Uses custom base62, limiting cross-language compatibility. +type V0 struct{} + +func (V0) Version() byte { return 0 } + +func (V0) Encode(_ context.Context, projectID uint64) string { + buf := make([]byte, sizeV0) + binary.BigEndian.PutUint64(buf, projectID) + _, _ = rand.Read(buf[8:]) + return base62.EncodeToString(buf) +} + +func (V0) Decode(accessKey string) (projectID uint64, err error) { + buf, err := base62.DecodeString(accessKey) + if err != nil { + return 0, fmt.Errorf("base62 decode: %w", err) + } + if len(buf) != sizeV0 { + return 0, ErrInvalidKeyLength + } + return binary.BigEndian.Uint64(buf[:8]), nil +} + +// V1: base64 encoded, 26-byte fixed length. 1 byte for version, 8 bytes for project ID, rest random. +// Uses standard base64url Compatible with other systems. +type V1 struct{} + +func (V1) Version() byte { return 1 } + +func (v V1) Encode(_ context.Context, projectID uint64) string { + buf := make([]byte, sizeV1) + buf[0] = v.Version() + binary.BigEndian.PutUint64(buf[1:], projectID) + _, _ = rand.Read(buf[9:]) + return base64.Base64UrlEncode(buf) +} + +func (V1) Decode(accessKey string) (projectID uint64, err error) { + buf, err := base64.Base64UrlDecode(accessKey) + if err != nil { + return 0, fmt.Errorf("base64 decode: %w", err) + } + if len(buf) != sizeV1 { + return 0, ErrInvalidKeyLength + } + return binary.BigEndian.Uint64(buf[1:9]), nil +} + +// V2: base64 encoded, 32-byte fixed length. 1 byte for version, 8 bytes for project ID, rest random. +// Uses ":" as separator between prefix and base64 encoded data. +type V2 struct{} + +const ( + Separator = ":" + DefaultPrefix = "seq" +) + +func (V2) Version() byte { return 2 } + +func (v V2) Encode(ctx context.Context, projectID uint64) string { + buf := make([]byte, sizeV2) + buf[0] = v.Version() + binary.BigEndian.PutUint64(buf[1:], projectID) + _, _ = rand.Read(buf[9:]) + return getPrefix(ctx) + Separator + base64.Base64UrlEncode(buf) +} + +func (V2) Decode(accessKey string) (projectID uint64, err error) { + parts := strings.Split(accessKey, Separator) + accessKey = parts[len(parts)-1] + + buf, err := base64.Base64UrlDecode(accessKey) + if err != nil { + return 0, fmt.Errorf("base64 decode: %w", err) + } + if len(buf) != sizeV2 { + return 0, ErrInvalidKeyLength + } + return binary.BigEndian.Uint64(buf[1:9]), nil +} diff --git a/access_key_test.go b/access_key_test.go new file mode 100644 index 0000000..6f68b58 --- /dev/null +++ b/access_key_test.go @@ -0,0 +1,59 @@ +package authcontrol_test + +import ( + "context" + "testing" + + "github.com/0xsequence/authcontrol" + "github.com/stretchr/testify/require" +) + +func TestAccessKeyEncoding(t *testing.T) { + t.Run("v0", func(t *testing.T) { + ctx := authcontrol.WithVersion(context.Background(), 0) + projectID := uint64(12345) + accessKey := authcontrol.GenerateAccessKey(ctx, projectID) + t.Log("=> k", accessKey) + + outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey) + require.NoError(t, err) + require.Equal(t, projectID, outID) + }) + + t.Run("v1", func(t *testing.T) { + ctx := authcontrol.WithVersion(context.Background(), 1) + projectID := uint64(12345) + accessKey := authcontrol.GenerateAccessKey(ctx, projectID) + t.Log("=> k", accessKey) + outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey) + require.NoError(t, err) + require.Equal(t, projectID, outID) + }) + t.Run("v2", func(t *testing.T) { + ctx := authcontrol.WithVersion(context.Background(), 2) + projectID := uint64(12345) + accessKey := authcontrol.GenerateAccessKey(ctx, projectID) + t.Log("=> k", accessKey, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey)) + outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey) + require.NoError(t, err) + require.Equal(t, projectID, outID) + + ctx = authcontrol.WithPrefix(ctx, "newprefix:dev") + + accessKey2 := authcontrol.GenerateAccessKey(ctx, projectID) + t.Log("=> k", accessKey2, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey2)) + outID, err = authcontrol.GetProjectIDFromAccessKey(accessKey2) + require.NoError(t, err) + require.Equal(t, projectID, outID) + // retrocompatibility with the older prefix + outID, err = authcontrol.GetProjectIDFromAccessKey(accessKey) + require.NoError(t, err) + require.Equal(t, projectID, outID) + }) +} + +func TestDecode(t *testing.T) { + ctx := authcontrol.WithVersion(context.Background(), 2) + accessKey := authcontrol.GenerateAccessKey(ctx, 237) + t.Log("=> k", accessKey, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey)) +} diff --git a/cmd/access_key/main.go b/cmd/access_key/main.go new file mode 100644 index 0000000..152b1d5 --- /dev/null +++ b/cmd/access_key/main.go @@ -0,0 +1,68 @@ +package main + +import ( + "fmt" + "os" + + "github.com/0xsequence/authcontrol" + "github.com/spf13/cobra" +) + +var rootCmd = &cobra.Command{ + Use: "authcontrol", + Short: "Access Keys CLI", + Long: `A command line interface for managing access keys.`, +} + +var accessKeyCmd = &cobra.Command{ + Use: "access-key", + Short: "Manage access keys", + Long: `Generate and decode access key.`, +} + +var decodeCmd = &cobra.Command{ + Use: "decode", + Short: "Decode an access key", + Long: `Decode an access key to retrieve the project ID.`, + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) != 1 { + return fmt.Errorf("access key is required") + } + accessKey := args[0] + var ( + projectID uint64 + version byte + errs []error + ) + for _, e := range authcontrol.SupportedEncodings { + id, err := e.Decode(accessKey) + if err != nil { + errs = append(errs, fmt.Errorf("decode v%d: %w", e.Version(), err)) + continue + } + projectID = id + version = e.Version() + break + } + + if len(errs) == len(authcontrol.SupportedEncodings) { + return fmt.Errorf("failed to decode access key: %v", errs) + } + fmt.Println("Version: ", version) + fmt.Println("Project: ", projectID) + fmt.Println("AccessKey:", accessKey) + return nil + }, +} + +func init() { + accessKeyCmd.AddCommand(decodeCmd) + rootCmd.AddCommand(accessKeyCmd) +} + +func main() { + if err := rootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) + } +} diff --git a/common_test.go b/common_test.go index 9c09257..17a1e19 100644 --- a/common_test.go +++ b/common_test.go @@ -54,10 +54,10 @@ func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, pat handler.ServeHTTP(rr, req.WithContext(ctx)) if status := rr.Result().StatusCode; status < http.StatusOK || status >= http.StatusBadRequest { - w := proto.WebRPCError{} - err = json.Unmarshal(rr.Body.Bytes(), &w) - require.NoError(t, err) - return false, w + webrpcErr := proto.WebRPCError{} + err = json.Unmarshal(rr.Body.Bytes(), &webrpcErr) + require.NoError(t, err, "failed to unmarshal response body: %s", rr.Body.Bytes()) + return false, webrpcErr } return true, nil diff --git a/context.go b/context.go index f2e026a..b16a1e2 100644 --- a/context.go +++ b/context.go @@ -22,6 +22,8 @@ var ( ctxKeyAccessKey = &contextKey{"AccessKey"} ctxKeyProjectID = &contextKey{"ProjectID"} ctxKeyProject = &contextKey{"Project"} + ctxKeyPrefix = &contextKey{"Prefix"} + ctxKeyVersion = &contextKey{"Version"} ) // @@ -140,3 +142,29 @@ func GetProject[T any](ctx context.Context) (*T, bool) { v, ok := ctx.Value(ctxKeyProject).(*T) return v, ok } + +// Access Key + +// WithPrefix sets the prefix to the context. +func WithPrefix(ctx context.Context, prefix string) context.Context { + return context.WithValue(ctx, ctxKeyPrefix, prefix) +} + +// getPrefix returns the prefix from the context. If not set, it returns DefaultPrefix. +func getPrefix(ctx context.Context) string { + if v, _ := ctx.Value(ctxKeyPrefix).(string); v != "" { + return v + } + return DefaultPrefix +} + +// WithVersion sets the version to the context. +func WithVersion(ctx context.Context, version byte) context.Context { + return context.WithValue(ctx, ctxKeyVersion, version) +} + +// GetVersion returns the version from the context. If not set, it returns AccessKeyVersion. +func GetVersion(ctx context.Context) (byte, bool) { + v, ok := ctx.Value(ctxKeyVersion).(byte) + return v, ok +} diff --git a/go.mod b/go.mod index 2a29abd..5195ab9 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,9 @@ require ( github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect github.com/goccy/go-json v0.10.3 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/goware/base64 v0.1.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jxskiss/base62 v1.1.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/lestrrat-go/blackmagic v1.0.2 // indirect github.com/lestrrat-go/httpcc v1.0.1 // indirect @@ -34,6 +37,8 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/segmentio/asm v1.2.0 // indirect + github.com/spf13/cobra v1.9.1 // indirect + github.com/spf13/pflag v1.0.6 // indirect golang.org/x/crypto v0.31.0 // indirect golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.30.0 // indirect diff --git a/go.sum b/go.sum index 57fb7e3..36cd85b 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,7 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -26,6 +27,12 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/goware/base64 v0.1.0 h1:ehxtbkFiMMqfjh+WQ9jBKnpBXywuZfw3gePMr19MgWo= +github.com/goware/base64 v0.1.0/go.mod h1:8stO8YzeBOn5KTtFI4yBaQ2ZewlNUgSG1QGyWbNsPvw= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jxskiss/base62 v1.1.0 h1:A5zbF8v8WXx2xixnAKD2w+abC+sIzYJX+nxmhA6HWFw= +github.com/jxskiss/base62 v1.1.0/go.mod h1:HhWAlUXvxKThfOlZbcuFzsqwtF5TcqS9ru3y5GfjWAc= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -60,8 +67,13 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= +github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= +github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= +github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= +github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/middleware_test.go b/middleware_test.go index 02fd6b2..147a23c 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -57,7 +57,7 @@ func TestSession(t *testing.T) { MethodProject: authcontrol.NewACL(proto.SessionType_Project.OrHigher()...), MethodUser: authcontrol.NewACL(proto.SessionType_User.OrHigher()...), MethodAdmin: authcontrol.NewACL(proto.SessionType_Admin.OrHigher()...), - MethodService: authcontrol.NewACL(proto.SessionType_InternalService.OrHigher()...), + MethodService: authcontrol.NewACL(proto.SessionType_S2S.OrHigher()...), }} const ( @@ -100,13 +100,13 @@ func TestSession(t *testing.T) { {Session: proto.SessionType_User, Admin: true}, {Session: proto.SessionType_Admin}, {Session: proto.SessionType_Admin, AccessKey: AccessKey}, - {Session: proto.SessionType_InternalService}, - {Session: proto.SessionType_InternalService, AccessKey: AccessKey}, + {Session: proto.SessionType_S2S}, + {Session: proto.SessionType_S2S, AccessKey: AccessKey}, } for service := range ACLConfig { for _, method := range Methods { - types := ACLConfig[service][method] + expectedACL := ACLConfig[service][method] for _, tc := range testCases { s := strings.Builder{} fmt.Fprintf(&s, "%s/%s", method, tc.Session) @@ -131,7 +131,7 @@ func TestSession(t *testing.T) { claims = map[string]any{"account": address} case proto.SessionType_Admin: claims = map[string]any{"account": WalletAddress, "admin": true} - case proto.SessionType_InternalService: + case proto.SessionType_S2S: claims = map[string]any{"service": ServiceName} } @@ -143,8 +143,6 @@ func TestSession(t *testing.T) { options = append(options, jwt(authcontrol.S2SToken(JWTSecret, claims))) } - ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), options...) - session := tc.Session switch { case session == proto.SessionType_User && tc.Admin: @@ -153,7 +151,8 @@ func TestSession(t *testing.T) { session = proto.SessionType_AccessKey } - if !types.Includes(session) { + ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), options...) + if !expectedACL.Includes(session) { assert.Error(t, err) assert.False(t, ok) return