Skip to content

Commit a2f5197

Browse files
committed
Handle case where DIAL_RSP contains error
1 parent ff5e07c commit a2f5197

File tree

2 files changed

+44
-6
lines changed

2 files changed

+44
-6
lines changed

pkg/server/server.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -465,14 +465,17 @@ func (s *ProxyServer) serveRecvBackend(stream agent.AgentService_ConnectServer,
465465
} else {
466466
err := client.send(pkt)
467467
s.PendingDial.Remove(resp.Random)
468-
if err != nil {
468+
if resp.Error != "" {
469+
klog.Warningf("<<< DIAL_RSP received error: %v", resp.Error)
470+
break
471+
} else if err != nil {
469472
klog.Warningf("<<< DIAL_RSP send to client stream error: %v", err)
470-
} else {
471-
client.connectID = resp.ConnectID
472-
client.agentID = agentID
473-
s.addFrontend(agentID, resp.ConnectID, client)
474-
close(client.connected)
473+
break
475474
}
475+
client.connectID = resp.ConnectID
476+
client.agentID = agentID
477+
s.addFrontend(agentID, resp.ConnectID, client)
478+
close(client.connected)
476479
}
477480

478481
case client.PacketType_DATA:

tests/proxy_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,41 @@ func TestBasicProxy_GRPC(t *testing.T) {
9292
}
9393
}
9494

95+
func TestProxyHandleError_GRPC(t *testing.T) {
96+
invalidServer := httptest.NewUnstartedServer(newEchoServer("hello"))
97+
98+
stopCh := make(chan struct{})
99+
defer close(stopCh)
100+
101+
proxy, cleanup, err := runGRPCProxyServer()
102+
if err != nil {
103+
t.Fatal(err)
104+
}
105+
defer cleanup()
106+
107+
runAgent(proxy.agent, stopCh)
108+
109+
// Wait for agent to register on proxy server
110+
time.Sleep(time.Second)
111+
112+
// run test client
113+
tunnel, err := client.CreateGrpcTunnel(proxy.front, grpc.WithInsecure())
114+
if err != nil {
115+
t.Fatal(err)
116+
}
117+
118+
c := &http.Client{
119+
Transport: &http.Transport{
120+
Dial: tunnel.Dial,
121+
},
122+
}
123+
124+
_, err = c.Get(invalidServer.URL)
125+
if err == nil {
126+
t.Error("Expected error when destination is unreachable, did not receive error")
127+
}
128+
}
129+
95130
func TestProxy_LargeResponse(t *testing.T) {
96131
length := 1 << 20 // 1M
97132
chunks := 10

0 commit comments

Comments
 (0)