@@ -4,13 +4,11 @@ import (
44 "context"
55 "errors"
66 "fmt"
7- "net/http"
87 "strings"
98
109 "crypto/rand"
1110 "encoding/binary"
1211
13- "github.com/go-chi/transport"
1412 "github.com/goware/base64"
1513 "github.com/jxskiss/base62"
1614)
@@ -24,10 +22,16 @@ var (
2422 ErrInvalidKeyLength = errors .New ("invalid access key length" )
2523)
2624
27- func GetProjectIDFromAccessKey (accessKey string ) (projectID uint64 , err error ) {
25+ type AccessKey string
26+
27+ func (a AccessKey ) String () string {
28+ return string (a )
29+ }
30+
31+ func (a AccessKey ) GetProjectID () (projectID uint64 , err error ) {
2832 var errs []error
2933 for _ , e := range SupportedEncodings {
30- projectID , err := e .Decode (accessKey )
34+ projectID , err := e .Decode (a )
3135 if err != nil {
3236 errs = append (errs , fmt .Errorf ("decode v%d: %w" , e .Version (), err ))
3337 continue
@@ -37,44 +41,34 @@ func GetProjectIDFromAccessKey(accessKey string) (projectID uint64, err error) {
3741 return 0 , errors .Join (errs ... )
3842}
3943
40- func GenerateAccessKey (ctx context.Context , projectID uint64 ) string {
41- version , ok := GetVersion (ctx )
42- if ! ok {
43- return DefaultEncoding .Encode (ctx , projectID )
44- }
45-
46- for _ , e := range SupportedEncodings {
47- if e .Version () == version {
48- return e .Encode (ctx , projectID )
49- }
50- }
51- return ""
52- }
53-
54- func GetAccessKeyPrefix (accessKey string ) string {
55- parts := strings .Split (accessKey , Separator )
44+ func (a AccessKey ) GetPrefix () string {
45+ parts := strings .Split (a .String (), Separator )
5646 if len (parts ) < 2 {
5747 return ""
5848 }
5949 return strings .Join (parts [:len (parts )- 1 ], Separator )
6050}
6151
62- func ForwardAccessKeyTransport (next http.RoundTripper ) http.RoundTripper {
63- return transport .RoundTripFunc (func (req * http.Request ) (resp * http.Response , err error ) {
64- r := transport .CloneRequest (req )
52+ var ErrUnsupportedEncoding = errors .New ("unsupported access key encoding" )
6553
66- if accessKey , ok := GetAccessKey (req .Context ()); ok {
67- r .Header .Set (HeaderAccessKey , accessKey )
68- }
54+ func GenerateAccessKey (ctx context.Context , projectID uint64 ) (AccessKey , error ) {
55+ version , ok := GetVersion (ctx )
56+ if ! ok {
57+ return DefaultEncoding .Encode (ctx , projectID ), nil
58+ }
6959
70- return next .RoundTrip (r )
71- })
60+ for _ , e := range SupportedEncodings {
61+ if e .Version () == version {
62+ return e .Encode (ctx , projectID ), nil
63+ }
64+ }
65+ return "" , ErrUnsupportedEncoding
7266}
7367
7468type Encoding interface {
7569 Version () byte
76- Encode (ctx context.Context , projectID uint64 ) string
77- Decode (accessKey string ) (projectID uint64 , err error )
70+ Encode (ctx context.Context , projectID uint64 ) AccessKey
71+ Decode (accessKey AccessKey ) (projectID uint64 , err error )
7872}
7973
8074const (
@@ -89,15 +83,15 @@ type V0 struct{}
8983
9084func (V0 ) Version () byte { return 0 }
9185
92- func (V0 ) Encode (_ context.Context , projectID uint64 ) string {
86+ func (V0 ) Encode (_ context.Context , projectID uint64 ) AccessKey {
9387 buf := make ([]byte , sizeV0 )
9488 binary .BigEndian .PutUint64 (buf , projectID )
9589 _ , _ = rand .Read (buf [8 :])
96- return base62 .EncodeToString (buf )
90+ return AccessKey ( base62 .EncodeToString (buf ) )
9791}
9892
99- func (V0 ) Decode (accessKey string ) (projectID uint64 , err error ) {
100- buf , err := base62 .DecodeString (accessKey )
93+ func (V0 ) Decode (accessKey AccessKey ) (projectID uint64 , err error ) {
94+ buf , err := base62 .DecodeString (accessKey . String () )
10195 if err != nil {
10296 return 0 , fmt .Errorf ("base62 decode: %w" , err )
10397 }
@@ -113,16 +107,16 @@ type V1 struct{}
113107
114108func (V1 ) Version () byte { return 1 }
115109
116- func (v V1 ) Encode (_ context.Context , projectID uint64 ) string {
110+ func (v V1 ) Encode (_ context.Context , projectID uint64 ) AccessKey {
117111 buf := make ([]byte , sizeV1 )
118112 buf [0 ] = v .Version ()
119113 binary .BigEndian .PutUint64 (buf [1 :], projectID )
120114 _ , _ = rand .Read (buf [9 :])
121- return base64 .Base64UrlEncode (buf )
115+ return AccessKey ( base64 .Base64UrlEncode (buf ) )
122116}
123117
124- func (V1 ) Decode (accessKey string ) (projectID uint64 , err error ) {
125- buf , err := base64 .Base64UrlDecode (accessKey )
118+ func (V1 ) Decode (accessKey AccessKey ) (projectID uint64 , err error ) {
119+ buf , err := base64 .Base64UrlDecode (accessKey . String () )
126120 if err != nil {
127121 return 0 , fmt .Errorf ("base64 decode: %w" , err )
128122 }
@@ -143,19 +137,19 @@ const (
143137
144138func (V2 ) Version () byte { return 2 }
145139
146- func (v V2 ) Encode (ctx context.Context , projectID uint64 ) string {
140+ func (v V2 ) Encode (ctx context.Context , projectID uint64 ) AccessKey {
147141 buf := make ([]byte , sizeV2 )
148142 buf [0 ] = v .Version ()
149143 binary .BigEndian .PutUint64 (buf [1 :], projectID )
150144 _ , _ = rand .Read (buf [9 :])
151- return getPrefix (ctx ) + Separator + base64 .Base64UrlEncode (buf )
145+ return AccessKey ( getPrefix (ctx ) + Separator + base64 .Base64UrlEncode (buf ) )
152146}
153147
154- func (V2 ) Decode (accessKey string ) (projectID uint64 , err error ) {
155- parts := strings .Split (accessKey , Separator )
156- accessKey = parts [len (parts )- 1 ]
148+ func (V2 ) Decode (accessKey AccessKey ) (projectID uint64 , err error ) {
149+ parts := strings .Split (accessKey . String () , Separator )
150+ raw : = parts [len (parts )- 1 ]
157151
158- buf , err := base64 .Base64UrlDecode (accessKey )
152+ buf , err := base64 .Base64UrlDecode (raw )
159153 if err != nil {
160154 return 0 , fmt .Errorf ("base64 decode: %w" , err )
161155 }
0 commit comments