Skip to content

Commit f3b0cb5

Browse files
authored
Merge pull request #118 from xia0pin9/feat/add-context-support
feat(client): add context.Context support for cancellation and timeout
2 parents 761bdb5 + b88a91c commit f3b0cb5

File tree

2 files changed

+304
-13
lines changed

2 files changed

+304
-13
lines changed

client.go

Lines changed: 81 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package knox
22

33
import (
44
"bytes"
5+
"context"
56
"crypto/tls"
67
"encoding/base64"
78
"encoding/json"
@@ -55,12 +56,12 @@ func (c *fileClient) update() error {
5556
var key Key
5657
f, err := os.Open("/var/lib/knox/v0/keys/" + c.keyID)
5758
if err != nil {
58-
return fmt.Errorf("Knox key file err: %s", err.Error())
59+
return fmt.Errorf("knox key file err: %w", err)
5960
}
6061
defer f.Close()
6162
err = json.NewDecoder(f).Decode(&key)
6263
if err != nil {
63-
return fmt.Errorf("Knox json decode err: %s", err.Error())
64+
return fmt.Errorf("knox json decode err: %w", err)
6465
}
6566
c.setValues(&key)
6667
return nil
@@ -73,8 +74,8 @@ func (c *fileClient) setValues(key *Key) {
7374
c.primary = string(key.VersionList.GetPrimary().Data)
7475
ks := key.VersionList.GetActive()
7576
c.active = make([]string, len(ks))
76-
for _, kv := range ks {
77-
c.active = append(c.active, string(kv.Data))
77+
for i, kv := range ks {
78+
c.active[i] = string(kv.Data)
7879
}
7980
}
8081

@@ -107,14 +108,14 @@ func NewFileClient(keyID string) (Client, error) {
107108
}
108109
err = json.Unmarshal(jsonKey, &key)
109110
if err != nil {
110-
return nil, fmt.Errorf("Knox json decode err: %s", err.Error())
111+
return nil, fmt.Errorf("knox json decode err: %w", err)
111112
}
112113
c.setValues(&key)
113114
go func() {
114115
for range time.Tick(refresh) {
115116
err := c.update()
116117
if err != nil {
117-
log.Println("Failed to update knox key ", err.Error())
118+
log.Println("failed to update knox key:", err)
118119
}
119120
}
120121
}()
@@ -187,7 +188,9 @@ type APIClient interface {
187188
AddVersion(keyID string, data []byte) (uint64, error)
188189
UpdateVersion(keyID, versionID string, status VersionStatus) error
189190
CacheGetKey(keyID string) (*Key, error)
191+
CacheGetKeyWithContext(ctx context.Context, keyID string) (*Key, error)
190192
NetworkGetKey(keyID string) (*Key, error)
193+
NetworkGetKeyWithContext(ctx context.Context, keyID string) (*Key, error)
191194
GetKeyWithStatus(keyID string, status VersionStatus) (*Key, error)
192195
CacheGetKeyWithStatus(keyID string, status VersionStatus) (*Key, error)
193196
NetworkGetKeyWithStatus(keyID string, status VersionStatus) (*Key, error)
@@ -220,15 +223,36 @@ func NewClient(host string, client HTTP, authHandlers []AuthHandler, keyFolder,
220223

221224
// CacheGetKey gets the key from file system cache.
222225
func (c *HTTPClient) CacheGetKey(keyID string) (*Key, error) {
226+
return c.CacheGetKeyWithContext(context.Background(), keyID)
227+
}
228+
229+
// CacheGetKeyWithContext gets the key from file system cache with context support.
230+
func (c *HTTPClient) CacheGetKeyWithContext(ctx context.Context, keyID string) (*Key, error) {
223231
if c.KeyFolder == "" {
224232
return nil, fmt.Errorf("no folder set for cached key")
225233
}
226-
path := path.Join(c.KeyFolder, keyID)
227-
b, err := os.ReadFile(path)
234+
235+
// Check context before file operations
236+
select {
237+
case <-ctx.Done():
238+
return nil, ctx.Err()
239+
default:
240+
}
241+
242+
keyPath := path.Join(c.KeyFolder, keyID)
243+
b, err := os.ReadFile(keyPath)
228244
if err != nil {
229245
return nil, err
230246
}
231-
k := Key{Path: path}
247+
248+
// Check context after file read but before JSON unmarshal
249+
select {
250+
case <-ctx.Done():
251+
return nil, ctx.Err()
252+
default:
253+
}
254+
255+
k := Key{Path: keyPath}
232256
err = json.Unmarshal(b, &k)
233257
if err != nil {
234258
return nil, err
@@ -247,6 +271,11 @@ func (c *HTTPClient) NetworkGetKey(keyID string) (*Key, error) {
247271
return c.UncachedClient.NetworkGetKey(keyID)
248272
}
249273

274+
// NetworkGetKeyWithContext gets a knox key by keyID and only uses network without the caches, with context support.
275+
func (c *HTTPClient) NetworkGetKeyWithContext(ctx context.Context, keyID string) (*Key, error) {
276+
return c.UncachedClient.NetworkGetKeyWithContext(ctx, keyID)
277+
}
278+
250279
// GetKey gets a knox key by keyID.
251280
func (c *HTTPClient) GetKey(keyID string) (*Key, error) {
252281
key, err := c.CacheGetKey(keyID)
@@ -364,8 +393,13 @@ func NewUncachedClient(host string, client HTTP, authHandlers []AuthHandler, ver
364393

365394
// NetworkGetKey gets a knox key by keyID and only uses network without the caches.
366395
func (c *UncachedHTTPClient) NetworkGetKey(keyID string) (*Key, error) {
396+
return c.NetworkGetKeyWithContext(context.Background(), keyID)
397+
}
398+
399+
// NetworkGetKeyWithContext gets a knox key by keyID and only uses network without the caches, with context support.
400+
func (c *UncachedHTTPClient) NetworkGetKeyWithContext(ctx context.Context, keyID string) (*Key, error) {
367401
key := &Key{}
368-
err := c.getHTTPData("GET", "/v0/keys/"+keyID+"/", nil, key)
402+
err := c.getHTTPDataWithContext(ctx, "GET", "/v0/keys/"+keyID+"/", nil, key)
369403
if err != nil {
370404
return nil, err
371405
}
@@ -380,7 +414,12 @@ func (c *UncachedHTTPClient) NetworkGetKey(keyID string) (*Key, error) {
380414

381415
// CacheGetKey acts same as NetworkGetKey for UncachedHTTPClient.
382416
func (c *UncachedHTTPClient) CacheGetKey(keyID string) (*Key, error) {
383-
return c.NetworkGetKey(keyID)
417+
return c.CacheGetKeyWithContext(context.Background(), keyID)
418+
}
419+
420+
// CacheGetKeyWithContext acts same as NetworkGetKeyWithContext for UncachedHTTPClient.
421+
func (c *UncachedHTTPClient) CacheGetKeyWithContext(ctx context.Context, keyID string) (*Key, error) {
422+
return c.NetworkGetKeyWithContext(ctx, keyID)
384423
}
385424

386425
// GetKey gets a knox key by keyID.
@@ -494,6 +533,10 @@ func (c *UncachedHTTPClient) getClient() (HTTP, error) {
494533
}
495534

496535
func (c *UncachedHTTPClient) getHTTPData(method string, path string, body url.Values, data interface{}) error {
536+
return c.getHTTPDataWithContext(context.Background(), method, path, body, data)
537+
}
538+
539+
func (c *UncachedHTTPClient) getHTTPDataWithContext(ctx context.Context, method string, path string, body url.Values, data interface{}) error {
497540
if len(c.AuthHandlers) == 0 {
498541
return errNoAuth
499542
}
@@ -502,6 +545,13 @@ func (c *UncachedHTTPClient) getHTTPData(method string, path string, body url.Va
502545
attemptedAuthTypes := []string{}
503546

504547
for _, authHandler := range c.AuthHandlers {
548+
// Check context before each auth handler attempt
549+
select {
550+
case <-ctx.Done():
551+
return ctx.Err()
552+
default:
553+
}
554+
505555
authToken, authType, clientOverride := authHandler()
506556
if authToken == "" {
507557
continue
@@ -511,7 +561,7 @@ func (c *UncachedHTTPClient) getHTTPData(method string, path string, body url.Va
511561

512562
// Create the request per authHandler to prevent body from being reused between requests.
513563
// This is due to the body being non-reusable after the first read.
514-
r, err := http.NewRequest(method, "https://"+c.Host+path, bytes.NewBufferString(body.Encode()))
564+
r, err := http.NewRequestWithContext(ctx, method, "https://"+c.Host+path, bytes.NewBufferString(body.Encode()))
515565
if err != nil {
516566
return err
517567
}
@@ -538,6 +588,13 @@ func (c *UncachedHTTPClient) getHTTPData(method string, path string, body url.Va
538588
resp.Data = data
539589
// Contains retry logic if we decode a 500 error.
540590
for i := 1; i <= maxRetryAttempts; i++ {
591+
// Check context before each retry attempt
592+
select {
593+
case <-ctx.Done():
594+
return ctx.Err()
595+
default:
596+
}
597+
541598
err = getHTTPResp(cli, r, resp)
542599
if err != nil {
543600
return err
@@ -552,7 +609,18 @@ func (c *UncachedHTTPClient) getHTTPData(method string, path string, body url.Va
552609
// If we get a 500, we need to retry the request.
553610
return fmt.Errorf(resp.Message)
554611
}
555-
time.Sleep(GetBackoffDuration(i))
612+
613+
// Check context before sleeping
614+
backoffDuration := GetBackoffDuration(i)
615+
timer := time.NewTimer(backoffDuration)
616+
select {
617+
case <-ctx.Done():
618+
timer.Stop()
619+
return ctx.Err()
620+
case <-timer.C:
621+
timer.Stop()
622+
// Continue to retry
623+
}
556624
}
557625
} else {
558626
// If we got a successful response, we can return the data.

0 commit comments

Comments
 (0)