Skip to content

Commit e22230a

Browse files
committed
Close HTTP connection if DIAL_RSP contains error
1 parent a2f5197 commit e22230a

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

pkg/server/server.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ func (c *ProxyClientConnection) send(pkt *client.Packet) error {
5555
_, err := c.HTTP.Write(pkt.GetData().Data)
5656
return err
5757
} else if pkt.Type == client.PacketType_DIAL_RSP {
58+
if pkt.GetDialResponse().Error != "" {
59+
return c.HTTP.Close()
60+
}
5861
return nil
5962
} else {
6063
return fmt.Errorf("attempt to send via unrecognized connection type %v", pkt.Type)
@@ -463,15 +466,22 @@ func (s *ProxyServer) serveRecvBackend(stream agent.AgentService_ConnectServer,
463466
if client, ok := s.PendingDial.Get(resp.Random); !ok {
464467
klog.Warning("<<< DialResp not recognized; dropped")
465468
} else {
469+
dialErr := false
470+
if resp.Error != "" {
471+
klog.Warningf("<<< DIAL_RSP contains error: %v", resp.Error)
472+
dialErr = true
473+
}
466474
err := client.send(pkt)
467475
s.PendingDial.Remove(resp.Random)
468-
if resp.Error != "" {
469-
klog.Warningf("<<< DIAL_RSP received error: %v", resp.Error)
470-
break
471-
} else if err != nil {
476+
if err != nil {
472477
klog.Warningf("<<< DIAL_RSP send to client stream error: %v", err)
478+
dialErr = true
479+
}
480+
// Avoid adding the frontend if there was an error dialing the destination
481+
if dialErr == true {
473482
break
474483
}
484+
475485
client.connectID = resp.ConnectID
476486
client.agentID = agentID
477487
s.addFrontend(agentID, resp.ConnectID, client)

tests/proxy_test.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net/http"
1010
"net/http/httptest"
1111
"net/url"
12+
"strings"
1213
"testing"
1314
"time"
1415

@@ -92,8 +93,8 @@ func TestBasicProxy_GRPC(t *testing.T) {
9293
}
9394
}
9495

95-
func TestProxyHandleError_GRPC(t *testing.T) {
96-
invalidServer := httptest.NewUnstartedServer(newEchoServer("hello"))
96+
func TestProxyHandleDialError_GRPC(t *testing.T) {
97+
invalidServer := httptest.NewServer(newEchoServer("hello"))
9798

9899
stopCh := make(chan struct{})
99100
defer close(stopCh)
@@ -121,8 +122,11 @@ func TestProxyHandleError_GRPC(t *testing.T) {
121122
},
122123
}
123124

124-
_, err = c.Get(invalidServer.URL)
125-
if err == nil {
125+
url := invalidServer.URL
126+
invalidServer.Close()
127+
128+
_, err = c.Get(url)
129+
if err == nil || !strings.Contains(err.Error(), "connection refused") {
126130
t.Error("Expected error when destination is unreachable, did not receive error")
127131
}
128132
}

0 commit comments

Comments
 (0)