@@ -2,6 +2,7 @@ package knox
22
33import (
44 "bytes"
5+ "context"
56 "crypto/tls"
67 "encoding/json"
78 "errors"
@@ -15,6 +16,7 @@ import (
1516 "strings"
1617 "sync/atomic"
1718 "testing"
19+ "time"
1820)
1921
2022type mockHTTPClient struct {
@@ -717,3 +719,224 @@ func TestNewFileClient(t *testing.T) {
717719 t .Fatalf ("Expected error starting with '%s', got: %v" , expectedPrefix , err )
718720 }
719721}
722+
723+ func TestCacheGetKeyWithContext (t * testing.T ) {
724+ expected := Key {
725+ ID : "testkey" ,
726+ ACL : ACL ([]Access {}),
727+ VersionList : KeyVersionList {},
728+ VersionHash : "VersionHash" ,
729+ }
730+
731+ keyBytes , err := json .Marshal (expected )
732+ if err != nil {
733+ t .Fatalf ("Error marshalling key: %s" , err )
734+ }
735+
736+ tempDir := t .TempDir ()
737+ err = os .WriteFile (path .Join (tempDir , "testkey" ), keyBytes , 0600 )
738+ if err != nil {
739+ t .Fatalf ("Failed to write test key: %s" , err )
740+ }
741+
742+ cli := & HTTPClient {
743+ KeyFolder : tempDir ,
744+ UncachedClient : & UncachedHTTPClient {},
745+ }
746+
747+ // Test with valid context
748+ ctx := context .Background ()
749+ k , err := cli .CacheGetKeyWithContext (ctx , "testkey" )
750+ if err != nil {
751+ t .Fatalf ("Unexpected error: %s" , err )
752+ }
753+ if k .ID != expected .ID {
754+ t .Fatalf ("Expected ID %s, got %s" , expected .ID , k .ID )
755+ }
756+
757+ // Test with canceled context
758+ canceledCtx , cancel := context .WithCancel (context .Background ())
759+ cancel ()
760+ _ , err = cli .CacheGetKeyWithContext (canceledCtx , "testkey" )
761+ if ! errors .Is (err , context .Canceled ) {
762+ t .Fatalf ("Expected context.Canceled error, got: %v" , err )
763+ }
764+ }
765+
766+ func TestNetworkGetKeyWithContext (t * testing.T ) {
767+ expected := Key {
768+ ID : "testkey" ,
769+ ACL : ACL ([]Access {}),
770+ VersionList : KeyVersionList {},
771+ VersionHash : "VersionHash" ,
772+ }
773+ resp , err := buildGoodResponse (expected )
774+ if err != nil {
775+ t .Fatalf ("%s is not nil" , err )
776+ }
777+ srv := buildServer (200 , resp , func (r * http.Request ) {
778+ if r .Method != "GET" {
779+ t .Fatalf ("%s is not GET" , r .Method )
780+ }
781+ })
782+ defer srv .Close ()
783+
784+ cli := MockClient (srv .Listener .Addr ().String (), "" )
785+
786+ // Test with valid context
787+ ctx := context .Background ()
788+ k , err := cli .NetworkGetKeyWithContext (ctx , "testkey" )
789+ if err != nil {
790+ t .Fatalf ("Unexpected error: %s" , err )
791+ }
792+ if k .ID != expected .ID {
793+ t .Fatalf ("Expected ID %s, got %s" , expected .ID , k .ID )
794+ }
795+ }
796+
797+ func TestNetworkGetKeyWithContextCancellation (t * testing.T ) {
798+ // Create a server that delays response
799+ srv := httptest .NewTLSServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
800+ time .Sleep (500 * time .Millisecond )
801+ w .WriteHeader (200 )
802+ w .Header ().Set ("Content-Type" , "application/json" )
803+ resp , _ := buildGoodResponse (Key {
804+ ID : "testkey" ,
805+ ACL : ACL ([]Access {}),
806+ VersionList : KeyVersionList {},
807+ VersionHash : "VersionHash" ,
808+ })
809+ w .Write (resp )
810+ }))
811+ defer srv .Close ()
812+
813+ cli := MockClient (srv .Listener .Addr ().String (), "" )
814+
815+ // Test with context that times out before server responds
816+ ctx , cancel := context .WithTimeout (context .Background (), 50 * time .Millisecond )
817+ defer cancel ()
818+
819+ _ , err := cli .NetworkGetKeyWithContext (ctx , "testkey" )
820+ if err == nil {
821+ t .Fatal ("Expected error due to context timeout, got nil" )
822+ }
823+ // The error should be related to context deadline or cancellation
824+ if ! errors .Is (err , context .DeadlineExceeded ) && ! errors .Is (err , context .Canceled ) {
825+ // Also accept wrapped errors from http client
826+ if err .Error () != "context deadline exceeded" && ! errors .Is (errors .Unwrap (err ), context .DeadlineExceeded ) {
827+ t .Logf ("Got error: %v (type: %T)" , err , err )
828+ // Accept any error that indicates the request was canceled/timed out
829+ }
830+ }
831+ }
832+
833+ func TestContextCancellationDuringRetry (t * testing.T ) {
834+ var requestCount uint64
835+ srv := buildConcurrentServer (200 , func (r * http.Request ) []byte {
836+ count := atomic .AddUint64 (& requestCount , 1 )
837+ // Always return 500 error to trigger retry logic
838+ // This ensures the client enters the backoff/retry path
839+ if count <= 3 {
840+ resp , _ := buildErrorResponse (InternalServerErrorCode , nil )
841+ return resp
842+ }
843+ resp , _ := buildGoodResponse (Key {
844+ ID : "testkey" ,
845+ ACL : ACL ([]Access {}),
846+ VersionList : KeyVersionList {},
847+ VersionHash : "VersionHash" ,
848+ })
849+ return resp
850+ })
851+ defer srv .Close ()
852+
853+ cli := MockClient (srv .Listener .Addr ().String (), "" )
854+
855+ // Use a context that will be canceled during retry backoff.
856+ // The backoff duration starts at ~50ms, so 100ms should be enough
857+ // to make the first request but timeout during the backoff sleep.
858+ ctx , cancel := context .WithTimeout (context .Background (), 100 * time .Millisecond )
859+ defer cancel ()
860+
861+ _ , err := cli .NetworkGetKeyWithContext (ctx , "testkey" )
862+ // Should get an error because context times out during retry
863+ if err == nil {
864+ t .Fatal ("Expected error due to context timeout during retry, got nil" )
865+ }
866+ // Verify we got a context-related error
867+ if ! errors .Is (err , context .DeadlineExceeded ) && ! errors .Is (err , context .Canceled ) {
868+ t .Fatalf ("Expected context deadline or cancellation error, got: %v" , err )
869+ }
870+ }
871+
872+ func TestContextCancellationBeforeAuthHandler (t * testing.T ) {
873+ expected := Key {
874+ ID : "testkey" ,
875+ ACL : ACL ([]Access {}),
876+ VersionList : KeyVersionList {},
877+ VersionHash : "VersionHash" ,
878+ }
879+ resp , err := buildGoodResponse (expected )
880+ if err != nil {
881+ t .Fatalf ("%s is not nil" , err )
882+ }
883+
884+ var serverCalled bool
885+ srv := buildServer (200 , resp , func (r * http.Request ) {
886+ serverCalled = true
887+ })
888+ defer srv .Close ()
889+
890+ cli := MockClient (srv .Listener .Addr ().String (), "" )
891+
892+ // Cancel the context before making the request
893+ ctx , cancel := context .WithCancel (context .Background ())
894+ cancel ()
895+
896+ _ , err = cli .NetworkGetKeyWithContext (ctx , "testkey" )
897+ if err == nil {
898+ t .Fatal ("Expected error due to canceled context" )
899+ }
900+ if ! errors .Is (err , context .Canceled ) {
901+ t .Fatalf ("Expected context.Canceled error, got: %v" , err )
902+ }
903+ if serverCalled {
904+ t .Fatal ("Server should not have been called with canceled context" )
905+ }
906+ }
907+
908+ func TestUncachedClientCacheGetKeyWithContext (t * testing.T ) {
909+ // UncachedHTTPClient.CacheGetKeyWithContext should delegate to NetworkGetKeyWithContext
910+ expected := Key {
911+ ID : "testkey" ,
912+ ACL : ACL ([]Access {}),
913+ VersionList : KeyVersionList {},
914+ VersionHash : "VersionHash" ,
915+ }
916+ resp , err := buildGoodResponse (expected )
917+ if err != nil {
918+ t .Fatalf ("%s is not nil" , err )
919+ }
920+ srv := buildServer (200 , resp , func (r * http.Request ) {
921+ if r .Method != "GET" {
922+ t .Fatalf ("%s is not GET" , r .Method )
923+ }
924+ })
925+ defer srv .Close ()
926+
927+ cli := & UncachedHTTPClient {
928+ Host : srv .Listener .Addr ().String (),
929+ AuthHandlers : []AuthHandler {func () (string , string , HTTP ) { return "TESTAUTH" , "TESTAUTHTYPE" , nil }},
930+ DefaultClient : & http.Client {Transport : & http.Transport {TLSClientConfig : & tls.Config {InsecureSkipVerify : true }}},
931+ Version : "mock" ,
932+ }
933+
934+ ctx := context .Background ()
935+ k , err := cli .CacheGetKeyWithContext (ctx , "testkey" )
936+ if err != nil {
937+ t .Fatalf ("Unexpected error: %s" , err )
938+ }
939+ if k .ID != expected .ID {
940+ t .Fatalf ("Expected ID %s, got %s" , expected .ID , k .ID )
941+ }
942+ }
0 commit comments