@@ -4,11 +4,13 @@ import (
44 "context"
55 "errors"
66 "fmt"
7+ "net/http"
78 "strings"
89
910 "crypto/rand"
1011 "encoding/binary"
1112
13+ "github.com/go-chi/transport"
1214 "github.com/goware/base64"
1315 "github.com/jxskiss/base62"
1416)
@@ -22,16 +24,10 @@ var (
2224 ErrInvalidKeyLength = errors .New ("invalid access key length" )
2325)
2426
25- type AccessKey string
26-
27- func (a AccessKey ) String () string {
28- return string (a )
29- }
30-
31- func (a AccessKey ) GetProjectID () (projectID uint64 , err error ) {
27+ func GetProjectIDFromAccessKey (accessKey string ) (projectID uint64 , err error ) {
3228 var errs []error
3329 for _ , e := range SupportedEncodings {
34- projectID , err := e .Decode (a )
30+ projectID , err := e .Decode (accessKey )
3531 if err != nil {
3632 errs = append (errs , fmt .Errorf ("decode v%d: %w" , e .Version (), err ))
3733 continue
@@ -41,34 +37,44 @@ func (a AccessKey) GetProjectID() (projectID uint64, err error) {
4137 return 0 , errors .Join (errs ... )
4238}
4339
44- func (a AccessKey ) GetPrefix () string {
45- parts := strings .Split (a .String (), Separator )
46- if len (parts ) < 2 {
47- return ""
48- }
49- return strings .Join (parts [:len (parts )- 1 ], Separator )
50- }
51-
52- var ErrUnsupportedEncoding = errors .New ("unsupported access key encoding" )
53-
54- func GenerateAccessKey (ctx context.Context , projectID uint64 ) (AccessKey , error ) {
40+ func GenerateAccessKey (ctx context.Context , projectID uint64 ) string {
5541 version , ok := GetVersion (ctx )
5642 if ! ok {
57- return DefaultEncoding .Encode (ctx , projectID ), nil
43+ return DefaultEncoding .Encode (ctx , projectID )
5844 }
5945
6046 for _ , e := range SupportedEncodings {
6147 if e .Version () == version {
62- return e .Encode (ctx , projectID ), nil
48+ return e .Encode (ctx , projectID )
6349 }
6450 }
65- return "" , ErrUnsupportedEncoding
51+ return ""
52+ }
53+
54+ func GetAccessKeyPrefix (accessKey string ) string {
55+ parts := strings .Split (accessKey , Separator )
56+ if len (parts ) < 2 {
57+ return ""
58+ }
59+ return strings .Join (parts [:len (parts )- 1 ], Separator )
60+ }
61+
62+ func ForwardAccessKeyTransport (next http.RoundTripper ) http.RoundTripper {
63+ return transport .RoundTripFunc (func (req * http.Request ) (resp * http.Response , err error ) {
64+ r := transport .CloneRequest (req )
65+
66+ if accessKey , ok := GetAccessKey (req .Context ()); ok {
67+ r .Header .Set (HeaderAccessKey , accessKey )
68+ }
69+
70+ return next .RoundTrip (r )
71+ })
6672}
6773
6874type Encoding interface {
6975 Version () byte
70- Encode (ctx context.Context , projectID uint64 ) AccessKey
71- Decode (accessKey AccessKey ) (projectID uint64 , err error )
76+ Encode (ctx context.Context , projectID uint64 ) string
77+ Decode (accessKey string ) (projectID uint64 , err error )
7278}
7379
7480const (
@@ -83,15 +89,15 @@ type V0 struct{}
8389
8490func (V0 ) Version () byte { return 0 }
8591
86- func (V0 ) Encode (_ context.Context , projectID uint64 ) AccessKey {
92+ func (V0 ) Encode (_ context.Context , projectID uint64 ) string {
8793 buf := make ([]byte , sizeV0 )
8894 binary .BigEndian .PutUint64 (buf , projectID )
8995 _ , _ = rand .Read (buf [8 :])
90- return AccessKey ( base62 .EncodeToString (buf ) )
96+ return base62 .EncodeToString (buf )
9197}
9298
93- func (V0 ) Decode (accessKey AccessKey ) (projectID uint64 , err error ) {
94- buf , err := base62 .DecodeString (accessKey . String () )
99+ func (V0 ) Decode (accessKey string ) (projectID uint64 , err error ) {
100+ buf , err := base62 .DecodeString (accessKey )
95101 if err != nil {
96102 return 0 , fmt .Errorf ("base62 decode: %w" , err )
97103 }
@@ -107,16 +113,16 @@ type V1 struct{}
107113
108114func (V1 ) Version () byte { return 1 }
109115
110- func (v V1 ) Encode (_ context.Context , projectID uint64 ) AccessKey {
116+ func (v V1 ) Encode (_ context.Context , projectID uint64 ) string {
111117 buf := make ([]byte , sizeV1 )
112118 buf [0 ] = v .Version ()
113119 binary .BigEndian .PutUint64 (buf [1 :], projectID )
114120 _ , _ = rand .Read (buf [9 :])
115- return AccessKey ( base64 .Base64UrlEncode (buf ) )
121+ return base64 .Base64UrlEncode (buf )
116122}
117123
118- func (V1 ) Decode (accessKey AccessKey ) (projectID uint64 , err error ) {
119- buf , err := base64 .Base64UrlDecode (accessKey . String () )
124+ func (V1 ) Decode (accessKey string ) (projectID uint64 , err error ) {
125+ buf , err := base64 .Base64UrlDecode (accessKey )
120126 if err != nil {
121127 return 0 , fmt .Errorf ("base64 decode: %w" , err )
122128 }
@@ -137,19 +143,19 @@ const (
137143
138144func (V2 ) Version () byte { return 2 }
139145
140- func (v V2 ) Encode (ctx context.Context , projectID uint64 ) AccessKey {
146+ func (v V2 ) Encode (ctx context.Context , projectID uint64 ) string {
141147 buf := make ([]byte , sizeV2 )
142148 buf [0 ] = v .Version ()
143149 binary .BigEndian .PutUint64 (buf [1 :], projectID )
144150 _ , _ = rand .Read (buf [9 :])
145- return AccessKey ( getPrefix (ctx ) + Separator + base64 .Base64UrlEncode (buf ) )
151+ return getPrefix (ctx ) + Separator + base64 .Base64UrlEncode (buf )
146152}
147153
148- func (V2 ) Decode (accessKey AccessKey ) (projectID uint64 , err error ) {
149- parts := strings .Split (accessKey . String () , Separator )
150- raw : = parts [len (parts )- 1 ]
154+ func (V2 ) Decode (accessKey string ) (projectID uint64 , err error ) {
155+ parts := strings .Split (accessKey , Separator )
156+ accessKey = parts [len (parts )- 1 ]
151157
152- buf , err := base64 .Base64UrlDecode (raw )
158+ buf , err := base64 .Base64UrlDecode (accessKey )
153159 if err != nil {
154160 return 0 , fmt .Errorf ("base64 decode: %w" , err )
155161 }
0 commit comments