Skip to content

Commit e6e9418

Browse files
authored
Merge pull request #94 from Jefftree/dial_rsp
Don't add frontend on proxy server if DIAL_RSP contains error
2 parents 42e9699 + e22230a commit e6e9418

File tree

2 files changed

+57
-5
lines changed

2 files changed

+57
-5
lines changed

pkg/server/server.go

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ func (c *ProxyClientConnection) send(pkt *client.Packet) error {
5555
_, err := c.HTTP.Write(pkt.GetData().Data)
5656
return err
5757
} else if pkt.Type == client.PacketType_DIAL_RSP {
58+
if pkt.GetDialResponse().Error != "" {
59+
return c.HTTP.Close()
60+
}
5861
return nil
5962
} else {
6063
return fmt.Errorf("attempt to send via unrecognized connection type %v", pkt.Type)
@@ -468,16 +471,26 @@ func (s *ProxyServer) serveRecvBackend(stream agent.AgentService_ConnectServer,
468471
if client, ok := s.PendingDial.Get(resp.Random); !ok {
469472
klog.Warning("<<< DialResp not recognized; dropped")
470473
} else {
474+
dialErr := false
475+
if resp.Error != "" {
476+
klog.Warningf("<<< DIAL_RSP contains error: %v", resp.Error)
477+
dialErr = true
478+
}
471479
err := client.send(pkt)
472480
s.PendingDial.Remove(resp.Random)
473481
if err != nil {
474482
klog.Warningf("<<< DIAL_RSP send to client stream error: %v", err)
475-
} else {
476-
client.connectID = resp.ConnectID
477-
client.agentID = agentID
478-
s.addFrontend(agentID, resp.ConnectID, client)
479-
close(client.connected)
483+
dialErr = true
480484
}
485+
// Avoid adding the frontend if there was an error dialing the destination
486+
if dialErr == true {
487+
break
488+
}
489+
490+
client.connectID = resp.ConnectID
491+
client.agentID = agentID
492+
s.addFrontend(agentID, resp.ConnectID, client)
493+
close(client.connected)
481494
}
482495

483496
case client.PacketType_DATA:

tests/proxy_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net/http"
1010
"net/http/httptest"
1111
"net/url"
12+
"strings"
1213
"testing"
1314
"time"
1415

@@ -92,6 +93,44 @@ func TestBasicProxy_GRPC(t *testing.T) {
9293
}
9394
}
9495

96+
func TestProxyHandleDialError_GRPC(t *testing.T) {
97+
invalidServer := httptest.NewServer(newEchoServer("hello"))
98+
99+
stopCh := make(chan struct{})
100+
defer close(stopCh)
101+
102+
proxy, cleanup, err := runGRPCProxyServer()
103+
if err != nil {
104+
t.Fatal(err)
105+
}
106+
defer cleanup()
107+
108+
runAgent(proxy.agent, stopCh)
109+
110+
// Wait for agent to register on proxy server
111+
time.Sleep(time.Second)
112+
113+
// run test client
114+
tunnel, err := client.CreateGrpcTunnel(proxy.front, grpc.WithInsecure())
115+
if err != nil {
116+
t.Fatal(err)
117+
}
118+
119+
c := &http.Client{
120+
Transport: &http.Transport{
121+
Dial: tunnel.Dial,
122+
},
123+
}
124+
125+
url := invalidServer.URL
126+
invalidServer.Close()
127+
128+
_, err = c.Get(url)
129+
if err == nil || !strings.Contains(err.Error(), "connection refused") {
130+
t.Error("Expected error when destination is unreachable, did not receive error")
131+
}
132+
}
133+
95134
func TestProxy_LargeResponse(t *testing.T) {
96135
length := 1 << 20 // 1M
97136
chunks := 10

0 commit comments

Comments
 (0)