Skip to content

Commit 5348b54

Browse files
committed
Fix nil pointer dereference in TCPTransport.DoRaw
Signed-off-by: robert-cronin <robert.owen.cronin@gmail.com>
1 parent 563976f commit 5348b54

File tree

2 files changed

+63
-9
lines changed

2 files changed

+63
-9
lines changed

transport/tcp.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package transport
22

33
import (
4+
"context"
45
"net"
56
"net/http"
67
)
@@ -14,19 +15,27 @@ func TCPTransport(host string, opts ...ConnectionOption) (*Transport, error) {
1415
}
1516
}
1617

18+
httpTransport := &http.Transport{
19+
DialContext: new(net.Dialer).DialContext,
20+
TLSClientConfig: cfg.TLSConfig,
21+
}
22+
23+
scheme := "http"
24+
if cfg.TLSConfig != nil {
25+
scheme = "https"
26+
}
27+
28+
dial := func(ctx context.Context) (net.Conn, error) {
29+
return httpTransport.DialContext(ctx, "tcp", host)
30+
}
31+
1732
t := &Transport{
18-
scheme: "http",
33+
scheme: scheme,
1934
host: host,
2035
c: &http.Client{
21-
Transport: &http.Transport{
22-
DialContext: new(net.Dialer).DialContext,
23-
TLSClientConfig: cfg.TLSConfig,
24-
},
36+
Transport: httpTransport,
2537
},
26-
}
27-
28-
if cfg.TLSConfig != nil {
29-
t.scheme = "https"
38+
dial: dial,
3039
}
3140

3241
return t, nil

transport/tcp_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package transport
22

33
import (
44
"context"
5+
"crypto/tls"
56
"io"
67
"net/http"
78
"net/http/httptest"
@@ -35,3 +36,47 @@ func TestTCPTransport(t *testing.T) {
3536
assert.NilError(t, err)
3637
assert.Equal(t, string(buf), data)
3738
}
39+
40+
func TestTCPTransportDoRaw(t *testing.T) {
41+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
42+
// for DoRaw test, we expect an upgrade request
43+
if req.Header.Get("Connection") == "Upgrade" {
44+
w.WriteHeader(http.StatusSwitchingProtocols)
45+
} else {
46+
w.WriteHeader(http.StatusNotFound)
47+
}
48+
}))
49+
defer srv.Close()
50+
51+
ctx := context.Background()
52+
53+
u, err := url.Parse(srv.URL)
54+
assert.NilError(t, err)
55+
56+
tr, err := TCPTransport(u.Host)
57+
assert.NilError(t, err)
58+
59+
conn, err := tr.DoRaw(ctx, "POST", "/grpc", WithUpgrade("h2c"))
60+
assert.Assert(t, conn != nil, "expected a connection but got nil")
61+
assert.NilError(t, err)
62+
}
63+
64+
func TestTCPTransportDoRawWithTLS(t *testing.T) {
65+
// test with TLS configuration
66+
tlsOpt := func(cfg *ConnectionConfig) error {
67+
cfg.TLSConfig = &tls.Config{
68+
InsecureSkipVerify: true,
69+
}
70+
return nil
71+
}
72+
73+
tr, err := TCPTransport("localhost:2376", tlsOpt)
74+
assert.NilError(t, err)
75+
76+
ctx := context.Background()
77+
78+
_, err = tr.DoRaw(ctx, "POST", "/grpc", WithUpgrade("h2c"))
79+
// we expect this to fail with connection error since theres no server
80+
// but it should not panic
81+
assert.Assert(t, err != nil, "expected an error but got none")
82+
}

0 commit comments

Comments
 (0)