Skip to content

Commit b88a91c

Browse files
committed
Add CacheGetKeyWithContext interface
1 parent 886472b commit b88a91c

File tree

2 files changed

+229
-0
lines changed

2 files changed

+229
-0
lines changed

client.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ type APIClient interface {
190190
CacheGetKey(keyID string) (*Key, error)
191191
CacheGetKeyWithContext(ctx context.Context, keyID string) (*Key, error)
192192
NetworkGetKey(keyID string) (*Key, error)
193+
NetworkGetKeyWithContext(ctx context.Context, keyID string) (*Key, error)
193194
GetKeyWithStatus(keyID string, status VersionStatus) (*Key, error)
194195
CacheGetKeyWithStatus(keyID string, status VersionStatus) (*Key, error)
195196
NetworkGetKeyWithStatus(keyID string, status VersionStatus) (*Key, error)
@@ -270,6 +271,11 @@ func (c *HTTPClient) NetworkGetKey(keyID string) (*Key, error) {
270271
return c.UncachedClient.NetworkGetKey(keyID)
271272
}
272273

274+
// NetworkGetKeyWithContext gets a knox key by keyID and only uses network without the caches, with context support.
275+
func (c *HTTPClient) NetworkGetKeyWithContext(ctx context.Context, keyID string) (*Key, error) {
276+
return c.UncachedClient.NetworkGetKeyWithContext(ctx, keyID)
277+
}
278+
273279
// GetKey gets a knox key by keyID.
274280
func (c *HTTPClient) GetKey(keyID string) (*Key, error) {
275281
key, err := c.CacheGetKey(keyID)

client_test.go

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package knox
22

33
import (
44
"bytes"
5+
"context"
56
"crypto/tls"
67
"encoding/json"
78
"errors"
@@ -15,6 +16,7 @@ import (
1516
"strings"
1617
"sync/atomic"
1718
"testing"
19+
"time"
1820
)
1921

2022
type mockHTTPClient struct {
@@ -717,3 +719,224 @@ func TestNewFileClient(t *testing.T) {
717719
t.Fatalf("Expected error starting with '%s', got: %v", expectedPrefix, err)
718720
}
719721
}
722+
723+
func TestCacheGetKeyWithContext(t *testing.T) {
724+
expected := Key{
725+
ID: "testkey",
726+
ACL: ACL([]Access{}),
727+
VersionList: KeyVersionList{},
728+
VersionHash: "VersionHash",
729+
}
730+
731+
keyBytes, err := json.Marshal(expected)
732+
if err != nil {
733+
t.Fatalf("Error marshalling key: %s", err)
734+
}
735+
736+
tempDir := t.TempDir()
737+
err = os.WriteFile(path.Join(tempDir, "testkey"), keyBytes, 0600)
738+
if err != nil {
739+
t.Fatalf("Failed to write test key: %s", err)
740+
}
741+
742+
cli := &HTTPClient{
743+
KeyFolder: tempDir,
744+
UncachedClient: &UncachedHTTPClient{},
745+
}
746+
747+
// Test with valid context
748+
ctx := context.Background()
749+
k, err := cli.CacheGetKeyWithContext(ctx, "testkey")
750+
if err != nil {
751+
t.Fatalf("Unexpected error: %s", err)
752+
}
753+
if k.ID != expected.ID {
754+
t.Fatalf("Expected ID %s, got %s", expected.ID, k.ID)
755+
}
756+
757+
// Test with canceled context
758+
canceledCtx, cancel := context.WithCancel(context.Background())
759+
cancel()
760+
_, err = cli.CacheGetKeyWithContext(canceledCtx, "testkey")
761+
if !errors.Is(err, context.Canceled) {
762+
t.Fatalf("Expected context.Canceled error, got: %v", err)
763+
}
764+
}
765+
766+
func TestNetworkGetKeyWithContext(t *testing.T) {
767+
expected := Key{
768+
ID: "testkey",
769+
ACL: ACL([]Access{}),
770+
VersionList: KeyVersionList{},
771+
VersionHash: "VersionHash",
772+
}
773+
resp, err := buildGoodResponse(expected)
774+
if err != nil {
775+
t.Fatalf("%s is not nil", err)
776+
}
777+
srv := buildServer(200, resp, func(r *http.Request) {
778+
if r.Method != "GET" {
779+
t.Fatalf("%s is not GET", r.Method)
780+
}
781+
})
782+
defer srv.Close()
783+
784+
cli := MockClient(srv.Listener.Addr().String(), "")
785+
786+
// Test with valid context
787+
ctx := context.Background()
788+
k, err := cli.NetworkGetKeyWithContext(ctx, "testkey")
789+
if err != nil {
790+
t.Fatalf("Unexpected error: %s", err)
791+
}
792+
if k.ID != expected.ID {
793+
t.Fatalf("Expected ID %s, got %s", expected.ID, k.ID)
794+
}
795+
}
796+
797+
func TestNetworkGetKeyWithContextCancellation(t *testing.T) {
798+
// Create a server that delays response
799+
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
800+
time.Sleep(500 * time.Millisecond)
801+
w.WriteHeader(200)
802+
w.Header().Set("Content-Type", "application/json")
803+
resp, _ := buildGoodResponse(Key{
804+
ID: "testkey",
805+
ACL: ACL([]Access{}),
806+
VersionList: KeyVersionList{},
807+
VersionHash: "VersionHash",
808+
})
809+
w.Write(resp)
810+
}))
811+
defer srv.Close()
812+
813+
cli := MockClient(srv.Listener.Addr().String(), "")
814+
815+
// Test with context that times out before server responds
816+
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
817+
defer cancel()
818+
819+
_, err := cli.NetworkGetKeyWithContext(ctx, "testkey")
820+
if err == nil {
821+
t.Fatal("Expected error due to context timeout, got nil")
822+
}
823+
// The error should be related to context deadline or cancellation
824+
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
825+
// Also accept wrapped errors from http client
826+
if err.Error() != "context deadline exceeded" && !errors.Is(errors.Unwrap(err), context.DeadlineExceeded) {
827+
t.Logf("Got error: %v (type: %T)", err, err)
828+
// Accept any error that indicates the request was canceled/timed out
829+
}
830+
}
831+
}
832+
833+
func TestContextCancellationDuringRetry(t *testing.T) {
834+
var requestCount uint64
835+
srv := buildConcurrentServer(200, func(r *http.Request) []byte {
836+
count := atomic.AddUint64(&requestCount, 1)
837+
// Always return 500 error to trigger retry logic
838+
// This ensures the client enters the backoff/retry path
839+
if count <= 3 {
840+
resp, _ := buildErrorResponse(InternalServerErrorCode, nil)
841+
return resp
842+
}
843+
resp, _ := buildGoodResponse(Key{
844+
ID: "testkey",
845+
ACL: ACL([]Access{}),
846+
VersionList: KeyVersionList{},
847+
VersionHash: "VersionHash",
848+
})
849+
return resp
850+
})
851+
defer srv.Close()
852+
853+
cli := MockClient(srv.Listener.Addr().String(), "")
854+
855+
// Use a context that will be canceled during retry backoff.
856+
// The backoff duration starts at ~50ms, so 100ms should be enough
857+
// to make the first request but timeout during the backoff sleep.
858+
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
859+
defer cancel()
860+
861+
_, err := cli.NetworkGetKeyWithContext(ctx, "testkey")
862+
// Should get an error because context times out during retry
863+
if err == nil {
864+
t.Fatal("Expected error due to context timeout during retry, got nil")
865+
}
866+
// Verify we got a context-related error
867+
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
868+
t.Fatalf("Expected context deadline or cancellation error, got: %v", err)
869+
}
870+
}
871+
872+
func TestContextCancellationBeforeAuthHandler(t *testing.T) {
873+
expected := Key{
874+
ID: "testkey",
875+
ACL: ACL([]Access{}),
876+
VersionList: KeyVersionList{},
877+
VersionHash: "VersionHash",
878+
}
879+
resp, err := buildGoodResponse(expected)
880+
if err != nil {
881+
t.Fatalf("%s is not nil", err)
882+
}
883+
884+
var serverCalled bool
885+
srv := buildServer(200, resp, func(r *http.Request) {
886+
serverCalled = true
887+
})
888+
defer srv.Close()
889+
890+
cli := MockClient(srv.Listener.Addr().String(), "")
891+
892+
// Cancel the context before making the request
893+
ctx, cancel := context.WithCancel(context.Background())
894+
cancel()
895+
896+
_, err = cli.NetworkGetKeyWithContext(ctx, "testkey")
897+
if err == nil {
898+
t.Fatal("Expected error due to canceled context")
899+
}
900+
if !errors.Is(err, context.Canceled) {
901+
t.Fatalf("Expected context.Canceled error, got: %v", err)
902+
}
903+
if serverCalled {
904+
t.Fatal("Server should not have been called with canceled context")
905+
}
906+
}
907+
908+
func TestUncachedClientCacheGetKeyWithContext(t *testing.T) {
909+
// UncachedHTTPClient.CacheGetKeyWithContext should delegate to NetworkGetKeyWithContext
910+
expected := Key{
911+
ID: "testkey",
912+
ACL: ACL([]Access{}),
913+
VersionList: KeyVersionList{},
914+
VersionHash: "VersionHash",
915+
}
916+
resp, err := buildGoodResponse(expected)
917+
if err != nil {
918+
t.Fatalf("%s is not nil", err)
919+
}
920+
srv := buildServer(200, resp, func(r *http.Request) {
921+
if r.Method != "GET" {
922+
t.Fatalf("%s is not GET", r.Method)
923+
}
924+
})
925+
defer srv.Close()
926+
927+
cli := &UncachedHTTPClient{
928+
Host: srv.Listener.Addr().String(),
929+
AuthHandlers: []AuthHandler{func() (string, string, HTTP) { return "TESTAUTH", "TESTAUTHTYPE", nil }},
930+
DefaultClient: &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}},
931+
Version: "mock",
932+
}
933+
934+
ctx := context.Background()
935+
k, err := cli.CacheGetKeyWithContext(ctx, "testkey")
936+
if err != nil {
937+
t.Fatalf("Unexpected error: %s", err)
938+
}
939+
if k.ID != expected.ID {
940+
t.Fatalf("Expected ID %s, got %s", expected.ID, k.ID)
941+
}
942+
}

0 commit comments

Comments
 (0)