@@ -2,6 +2,7 @@ package knox
22
33import (
44 "bytes"
5+ "context"
56 "crypto/tls"
67 "encoding/base64"
78 "encoding/json"
@@ -55,12 +56,12 @@ func (c *fileClient) update() error {
5556 var key Key
5657 f , err := os .Open ("/var/lib/knox/v0/keys/" + c .keyID )
5758 if err != nil {
58- return fmt .Errorf ("Knox key file err: %s " , err . Error () )
59+ return fmt .Errorf ("knox key file err: %w " , err )
5960 }
6061 defer f .Close ()
6162 err = json .NewDecoder (f ).Decode (& key )
6263 if err != nil {
63- return fmt .Errorf ("Knox json decode err: %s " , err . Error () )
64+ return fmt .Errorf ("knox json decode err: %w " , err )
6465 }
6566 c .setValues (& key )
6667 return nil
@@ -73,8 +74,8 @@ func (c *fileClient) setValues(key *Key) {
7374 c .primary = string (key .VersionList .GetPrimary ().Data )
7475 ks := key .VersionList .GetActive ()
7576 c .active = make ([]string , len (ks ))
76- for _ , kv := range ks {
77- c .active = append ( c . active , string (kv .Data ) )
77+ for i , kv := range ks {
78+ c .active [ i ] = string (kv .Data )
7879 }
7980}
8081
@@ -107,14 +108,14 @@ func NewFileClient(keyID string) (Client, error) {
107108 }
108109 err = json .Unmarshal (jsonKey , & key )
109110 if err != nil {
110- return nil , fmt .Errorf ("Knox json decode err: %s " , err . Error () )
111+ return nil , fmt .Errorf ("knox json decode err: %w " , err )
111112 }
112113 c .setValues (& key )
113114 go func () {
114115 for range time .Tick (refresh ) {
115116 err := c .update ()
116117 if err != nil {
117- log .Println ("Failed to update knox key " , err . Error () )
118+ log .Println ("failed to update knox key: " , err )
118119 }
119120 }
120121 }()
@@ -187,7 +188,9 @@ type APIClient interface {
187188 AddVersion (keyID string , data []byte ) (uint64 , error )
188189 UpdateVersion (keyID , versionID string , status VersionStatus ) error
189190 CacheGetKey (keyID string ) (* Key , error )
191+ CacheGetKeyWithContext (ctx context.Context , keyID string ) (* Key , error )
190192 NetworkGetKey (keyID string ) (* Key , error )
193+ NetworkGetKeyWithContext (ctx context.Context , keyID string ) (* Key , error )
191194 GetKeyWithStatus (keyID string , status VersionStatus ) (* Key , error )
192195 CacheGetKeyWithStatus (keyID string , status VersionStatus ) (* Key , error )
193196 NetworkGetKeyWithStatus (keyID string , status VersionStatus ) (* Key , error )
@@ -220,15 +223,36 @@ func NewClient(host string, client HTTP, authHandlers []AuthHandler, keyFolder,
220223
221224// CacheGetKey gets the key from file system cache.
222225func (c * HTTPClient ) CacheGetKey (keyID string ) (* Key , error ) {
226+ return c .CacheGetKeyWithContext (context .Background (), keyID )
227+ }
228+
229+ // CacheGetKeyWithContext gets the key from file system cache with context support.
230+ func (c * HTTPClient ) CacheGetKeyWithContext (ctx context.Context , keyID string ) (* Key , error ) {
223231 if c .KeyFolder == "" {
224232 return nil , fmt .Errorf ("no folder set for cached key" )
225233 }
226- path := path .Join (c .KeyFolder , keyID )
227- b , err := os .ReadFile (path )
234+
235+ // Check context before file operations
236+ select {
237+ case <- ctx .Done ():
238+ return nil , ctx .Err ()
239+ default :
240+ }
241+
242+ keyPath := path .Join (c .KeyFolder , keyID )
243+ b , err := os .ReadFile (keyPath )
228244 if err != nil {
229245 return nil , err
230246 }
231- k := Key {Path : path }
247+
248+ // Check context after file read but before JSON unmarshal
249+ select {
250+ case <- ctx .Done ():
251+ return nil , ctx .Err ()
252+ default :
253+ }
254+
255+ k := Key {Path : keyPath }
232256 err = json .Unmarshal (b , & k )
233257 if err != nil {
234258 return nil , err
@@ -247,6 +271,11 @@ func (c *HTTPClient) NetworkGetKey(keyID string) (*Key, error) {
247271 return c .UncachedClient .NetworkGetKey (keyID )
248272}
249273
274+ // NetworkGetKeyWithContext gets a knox key by keyID and only uses network without the caches, with context support.
275+ func (c * HTTPClient ) NetworkGetKeyWithContext (ctx context.Context , keyID string ) (* Key , error ) {
276+ return c .UncachedClient .NetworkGetKeyWithContext (ctx , keyID )
277+ }
278+
250279// GetKey gets a knox key by keyID.
251280func (c * HTTPClient ) GetKey (keyID string ) (* Key , error ) {
252281 key , err := c .CacheGetKey (keyID )
@@ -364,8 +393,13 @@ func NewUncachedClient(host string, client HTTP, authHandlers []AuthHandler, ver
364393
365394// NetworkGetKey gets a knox key by keyID and only uses network without the caches.
366395func (c * UncachedHTTPClient ) NetworkGetKey (keyID string ) (* Key , error ) {
396+ return c .NetworkGetKeyWithContext (context .Background (), keyID )
397+ }
398+
399+ // NetworkGetKeyWithContext gets a knox key by keyID and only uses network without the caches, with context support.
400+ func (c * UncachedHTTPClient ) NetworkGetKeyWithContext (ctx context.Context , keyID string ) (* Key , error ) {
367401 key := & Key {}
368- err := c .getHTTPData ( "GET" , "/v0/keys/" + keyID + "/" , nil , key )
402+ err := c .getHTTPDataWithContext ( ctx , "GET" , "/v0/keys/" + keyID + "/" , nil , key )
369403 if err != nil {
370404 return nil , err
371405 }
@@ -380,7 +414,12 @@ func (c *UncachedHTTPClient) NetworkGetKey(keyID string) (*Key, error) {
380414
381415// CacheGetKey acts same as NetworkGetKey for UncachedHTTPClient.
382416func (c * UncachedHTTPClient ) CacheGetKey (keyID string ) (* Key , error ) {
383- return c .NetworkGetKey (keyID )
417+ return c .CacheGetKeyWithContext (context .Background (), keyID )
418+ }
419+
420+ // CacheGetKeyWithContext acts same as NetworkGetKeyWithContext for UncachedHTTPClient.
421+ func (c * UncachedHTTPClient ) CacheGetKeyWithContext (ctx context.Context , keyID string ) (* Key , error ) {
422+ return c .NetworkGetKeyWithContext (ctx , keyID )
384423}
385424
386425// GetKey gets a knox key by keyID.
@@ -494,6 +533,10 @@ func (c *UncachedHTTPClient) getClient() (HTTP, error) {
494533}
495534
496535func (c * UncachedHTTPClient ) getHTTPData (method string , path string , body url.Values , data interface {}) error {
536+ return c .getHTTPDataWithContext (context .Background (), method , path , body , data )
537+ }
538+
539+ func (c * UncachedHTTPClient ) getHTTPDataWithContext (ctx context.Context , method string , path string , body url.Values , data interface {}) error {
497540 if len (c .AuthHandlers ) == 0 {
498541 return errNoAuth
499542 }
@@ -502,6 +545,13 @@ func (c *UncachedHTTPClient) getHTTPData(method string, path string, body url.Va
502545 attemptedAuthTypes := []string {}
503546
504547 for _ , authHandler := range c .AuthHandlers {
548+ // Check context before each auth handler attempt
549+ select {
550+ case <- ctx .Done ():
551+ return ctx .Err ()
552+ default :
553+ }
554+
505555 authToken , authType , clientOverride := authHandler ()
506556 if authToken == "" {
507557 continue
@@ -511,7 +561,7 @@ func (c *UncachedHTTPClient) getHTTPData(method string, path string, body url.Va
511561
512562 // Create the request per authHandler to prevent body from being reused between requests.
513563 // This is due to the body being non-reusable after the first read.
514- r , err := http .NewRequest ( method , "https://" + c .Host + path , bytes .NewBufferString (body .Encode ()))
564+ r , err := http .NewRequestWithContext ( ctx , method , "https://" + c .Host + path , bytes .NewBufferString (body .Encode ()))
515565 if err != nil {
516566 return err
517567 }
@@ -538,6 +588,13 @@ func (c *UncachedHTTPClient) getHTTPData(method string, path string, body url.Va
538588 resp .Data = data
539589 // Contains retry logic if we decode a 500 error.
540590 for i := 1 ; i <= maxRetryAttempts ; i ++ {
591+ // Check context before each retry attempt
592+ select {
593+ case <- ctx .Done ():
594+ return ctx .Err ()
595+ default :
596+ }
597+
541598 err = getHTTPResp (cli , r , resp )
542599 if err != nil {
543600 return err
@@ -552,7 +609,18 @@ func (c *UncachedHTTPClient) getHTTPData(method string, path string, body url.Va
552609 // If we get a 500, we need to retry the request.
553610 return fmt .Errorf (resp .Message )
554611 }
555- time .Sleep (GetBackoffDuration (i ))
612+
613+ // Check context before sleeping
614+ backoffDuration := GetBackoffDuration (i )
615+ timer := time .NewTimer (backoffDuration )
616+ select {
617+ case <- ctx .Done ():
618+ timer .Stop ()
619+ return ctx .Err ()
620+ case <- timer .C :
621+ timer .Stop ()
622+ // Continue to retry
623+ }
556624 }
557625 } else {
558626 // If we got a successful response, we can return the data.
0 commit comments