diff --git a/pkg/agent/client.go b/pkg/agent/client.go index 5a19ec9ee..1004bd934 100644 --- a/pkg/agent/client.go +++ b/pkg/agent/client.go @@ -18,6 +18,7 @@ package agent import ( "context" + "errors" "fmt" "io" "net" @@ -151,6 +152,9 @@ type Client struct { serviceAccountTokenPath string warnOnChannelLimit bool + + // Here for testing + readBlockInterval time.Duration } func newAgentClient(address, agentID, agentIdentifiers string, cs *ClientSet, opts ...grpc.DialOption) (*Client, int, error) { @@ -166,6 +170,7 @@ func newAgentClient(address, agentID, agentIdentifiers string, cs *ClientSet, op serviceAccountTokenPath: cs.serviceAccountTokenPath, connManager: newConnectionManager(), warnOnChannelLimit: cs.warnOnChannelLimit, + readBlockInterval: 15 * time.Second, } serverCount, err := a.Connect() if err != nil { @@ -538,10 +543,19 @@ func (a *Client) remoteToProxy(connID int64, eConn *endpointConn) { } for { + select { + case <-a.stopCh: + return + default: + } + timeout := time.Now().Add(a.readBlockInterval) + eConn.conn.SetReadDeadline(timeout) n, err := eConn.conn.Read(buf[:]) klog.V(5).InfoS("received data from remote", "bytes", n, "connectionID", connID) - if err == io.EOF { + if errors.Is(err, os.ErrDeadlineExceeded) { + continue + } else if err == io.EOF { klog.V(2).InfoS("remote connection EOF", "connectionID", connID) return } else if err != nil { diff --git a/pkg/agent/client_test.go b/pkg/agent/client_test.go index 3cd2bf253..1ec6022d6 100644 --- a/pkg/agent/client_test.go +++ b/pkg/agent/client_test.go @@ -42,9 +42,10 @@ func TestServeData_HTTP(t *testing.T) { stopCh: stopCh, } testClient := &Client{ - connManager: newConnectionManager(), - stopCh: stopCh, - cs: cs, + connManager: newConnectionManager(), + stopCh: stopCh, + cs: cs, + readBlockInterval: 15 * time.Second, } testClient.stream, stream = pipe() @@ -133,7 +134,8 @@ func TestServeData_HTTP(t *testing.T) { waitForConnectionDeletion(t, testClient, connID) } -func TestClose_Client(t *testing.T) { +func TestDelayedServedData_HTTP(t *testing.T) { + var err error var stream agent.AgentService_ConnectClient stopCh := make(chan struct{}) cs := &ClientSet{ @@ -144,6 +146,113 @@ func TestClose_Client(t *testing.T) { connManager: newConnectionManager(), stopCh: stopCh, cs: cs, + // Set the readBlockInterval to a short value to check if + // the agent can handle the SetReadDeadline. + readBlockInterval: 1 * time.Second, + } + testClient.stream, stream = pipe() + + // Start agent + go testClient.Serve() + defer close(stopCh) + + // Start test http server as remote service + expectedBody := "Hello, client" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + time.Sleep(4 * time.Second) // HTTPTest times out after 5 seconds. + fmt.Fprint(w, expectedBody) + })) + defer ts.Close() + + // Simulate sending KAS DIAL_REQ to (Agent) Client + dialPacket := newDialPacket("tcp", ts.URL[len("http://"):], 111) + err = stream.Send(dialPacket) + if err != nil { + t.Fatal(err.Error()) + } + + // Expect receiving DIAL_RSP packet from (Agent) Client + pkt, err := stream.Recv() + if err != nil { + t.Fatal(err.Error()) + } + if pkt == nil { + t.Fatal("unexpected nil packet") + } + if pkt.Type != client.PacketType_DIAL_RSP { + t.Errorf("expect PacketType_DIAL_RSP; got %v", pkt.Type) + } + dialRsp := pkt.Payload.(*client.Packet_DialResponse) + connID := dialRsp.DialResponse.ConnectID + if dialRsp.DialResponse.Random != 111 { + t.Errorf("expect random=111; got %v", dialRsp.DialResponse.Random) + } + + // Send Data (HTTP Request) via (Agent) Client to the test http server + t.Logf("Sending data packet at %v", time.Now()) + dataPacket := newDataPacket(connID, []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n")) + err = stream.Send(dataPacket) + if err != nil { + t.Error(err.Error()) + } + t.Logf("Sent data packet at %v", time.Now()) + + // Expect receiving http response via (Agent) Client + t.Logf("Receiving http response at %v", time.Now()) + pkt, _ = stream.Recv() + if pkt == nil { + t.Fatalf("unexpected nil packet at %v", time.Now()) + } + if pkt.Type != client.PacketType_DATA { + t.Errorf("expect PacketType_DATA; got %v", pkt.Type) + } + data := pkt.Payload.(*client.Packet_Data).Data.Data + + // Verify response data + // + // HTTP/1.1 200 OK\r\n + // Date: Tue, 07 May 2019 06:44:57 GMT\r\n + // Content-Length: 14\r\n + // Content-Type: text/plain; charset=utf-8\r\n + // \r\n + // Hello, client + headAndBody := strings.Split(string(data), "\r\n") + if body := headAndBody[len(headAndBody)-1]; body != expectedBody { + t.Errorf("expect body %v; got %v", expectedBody, body) + } + + // Force close the test server which will cause remote connection gets droped + ts.Close() + + // Verify receiving CLOSE_RSP + pkt, _ = stream.Recv() + if pkt == nil { + t.Fatal("unexpected nil packet") + } + if pkt.Type != client.PacketType_CLOSE_RSP { + t.Errorf("expect PacketType_CLOSE_RSP; got %v", pkt.Type) + } + closeErr := pkt.Payload.(*client.Packet_CloseResponse).CloseResponse.Error + if closeErr != "" { + t.Errorf("expect nil closeErr; got %v", closeErr) + } + + // Verify internal state is consistent + waitForConnectionDeletion(t, testClient, connID) +} + +func TestClose_Client(t *testing.T) { + var stream agent.AgentService_ConnectClient + stopCh := make(chan struct{}) + cs := &ClientSet{ + clients: make(map[string]*Client), + stopCh: stopCh, + } + testClient := &Client{ + connManager: newConnectionManager(), + stopCh: stopCh, + cs: cs, + readBlockInterval: 15 * time.Second, } testClient.stream, stream = pipe() @@ -229,9 +338,10 @@ func TestConnectionMismatch(t *testing.T) { stopCh: stopCh, } testClient := &Client{ - connManager: newConnectionManager(), - stopCh: stopCh, - cs: cs, + connManager: newConnectionManager(), + stopCh: stopCh, + cs: cs, + readBlockInterval: 15 * time.Second, } testClient.stream, stream = pipe() @@ -291,9 +401,10 @@ func TestFailedSend_DialResp_GRPC(t *testing.T) { stopCh: stopCh, } testClient := &Client{ - connManager: newConnectionManager(), - stopCh: stopCh, - cs: cs, + connManager: newConnectionManager(), + stopCh: stopCh, + cs: cs, + readBlockInterval: 15 * time.Second, } defer func() { close(stopCh) @@ -353,10 +464,11 @@ func TestDrain(t *testing.T) { stopCh: stopCh, } testClient := &Client{ - connManager: newConnectionManager(), - drainCh: drainCh, - stopCh: stopCh, - cs: cs, + connManager: newConnectionManager(), + drainCh: drainCh, + stopCh: stopCh, + cs: cs, + readBlockInterval: 15 * time.Second, } testClient.stream, stream = pipe()