Skip to content

Commit 2260a7e

Browse files
authored
Merge pull request #45 from cheftako/uds
Adding support for UDS.
2 parents fea45dc + e5412e0 commit 2260a7e

File tree

2 files changed

+215
-21
lines changed

2 files changed

+215
-21
lines changed

cmd/client/main.go

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"net/http"
2929
"os"
3030
"os/signal"
31+
"time"
3132

3233
"github.com/spf13/cobra"
3334
"github.com/spf13/pflag"
@@ -68,6 +69,7 @@ type GrpcProxyClientOptions struct {
6869
requestPort int
6970
proxyHost string
7071
proxyPort int
72+
proxyUdsName string
7173
mode string
7274
}
7375

@@ -82,6 +84,7 @@ func (o *GrpcProxyClientOptions) Flags() *pflag.FlagSet {
8284
flags.IntVar(&o.requestPort, "request-port", o.requestPort, "The port the request server is listening on.")
8385
flags.StringVar(&o.proxyHost, "proxy-host", o.proxyHost, "The host of the proxy server.")
8486
flags.IntVar(&o.proxyPort, "proxy-port", o.proxyPort, "The port the proxy server is listening on.")
87+
flags.StringVar(&o.proxyUdsName, "proxy-uds", o.proxyHost, "The UDS name to connect to.")
8588
flags.StringVar(&o.mode, "mode", o.mode, "Mode can be either 'grpc' or 'http-connect'.")
8689

8790
return flags
@@ -97,6 +100,7 @@ func (o *GrpcProxyClientOptions) Print() {
97100
klog.Warningf("RequestPort set to %d.\n", o.requestPort)
98101
klog.Warningf("ProxyHost set to %q.\n", o.proxyHost)
99102
klog.Warningf("ProxyPort set to %d.\n", o.proxyPort)
103+
klog.Warningf("ProxyUdsName set to %q.\n", o.proxyUdsName)
100104
klog.Warningf("Mode set to %q.\n", o.mode)
101105
}
102106

@@ -134,9 +138,18 @@ func (o *GrpcProxyClientOptions) Validate() error {
134138
if o.proxyPort > 49151 {
135139
return fmt.Errorf("please do not try to use ephemeral port %d for the proxy server port", o.proxyPort)
136140
}
137-
if o.proxyPort < 1024 {
141+
if o.proxyPort < 1024 && o.proxyUdsName == "" {
138142
return fmt.Errorf("please do not try to use reserved port %d for the proxy server port", o.proxyPort)
139143
}
144+
if o.proxyUdsName != "" {
145+
if o.proxyPort != 0 {
146+
return fmt.Errorf("please do set proxy server port to 0 not %d when using UDS", o.proxyPort)
147+
}
148+
if o.clientKey != "" || o.clientCert != "" || o.caCert != "" {
149+
return fmt.Errorf("please do set cert materials when using UDS, key = %s, cert = %s, CA = %s",
150+
o.clientKey, o.clientCert, o.caCert)
151+
}
152+
}
140153
return nil
141154
}
142155

@@ -151,6 +164,7 @@ func newGrpcProxyClientOptions() *GrpcProxyClientOptions {
151164
requestPort: 8000,
152165
proxyHost: "localhost",
153166
proxyPort: 8090,
167+
proxyUdsName: "",
154168
mode: "grpc",
155169
}
156170
return &o
@@ -211,6 +225,86 @@ func (c *Client) run(o *GrpcProxyClientOptions) error {
211225
}
212226

213227
func (c *Client) getDialer(o *GrpcProxyClientOptions) (func(ctx context.Context, network, addr string) (net.Conn, error), error) {
228+
if o.proxyUdsName != "" {
229+
return c.getUDSDialer(o)
230+
}
231+
return c.getMTLSDialer(o)
232+
}
233+
234+
func (c *Client) getUDSDialer(o *GrpcProxyClientOptions) (func(ctx context.Context, network, addr string) (net.Conn, error), error) {
235+
var proxyConn net.Conn
236+
var err error
237+
238+
// Setup signal handler
239+
ch := make(chan os.Signal, 1)
240+
signal.Notify(ch)
241+
242+
go func() {
243+
<-ch
244+
if proxyConn != nil {
245+
err := proxyConn.Close()
246+
klog.Infof("connection closed: %v", err)
247+
}
248+
}()
249+
250+
switch o.mode {
251+
case "grpc":
252+
dialOption := grpc.WithDialer(func(string, time.Duration) (net.Conn, error) {
253+
// Ignoring addr and timeout arguments:
254+
// addr - comes from the closure
255+
// timeout - is turned off as this is test code and eases debugging.
256+
c, err := net.DialTimeout("unix", o.proxyUdsName, 0)
257+
if err != nil {
258+
klog.Errorf("failed to create connection to uds name %s, error: %v", o.proxyUdsName, err)
259+
}
260+
return c, err
261+
})
262+
tunnel, err := client.CreateGrpcTunnel(o.proxyUdsName, dialOption, grpc.WithInsecure())
263+
if err != nil {
264+
return nil, fmt.Errorf("failed to create tunnel %s, got %v", o.proxyUdsName, err)
265+
}
266+
267+
requestAddress := fmt.Sprintf("%s:%d", o.requestHost, o.requestPort)
268+
proxyConn, err = tunnel.Dial("tcp", requestAddress)
269+
if err != nil {
270+
return nil, fmt.Errorf("failed to dial request %s, got %v", requestAddress, err)
271+
}
272+
case "http-connect":
273+
requestAddress := fmt.Sprintf("%s:%d", o.requestHost, o.requestPort)
274+
275+
proxyConn, err = net.Dial("unix", o.proxyUdsName)
276+
if err != nil {
277+
return nil, fmt.Errorf("dialing proxy %q failed: %v", o.proxyUdsName, err)
278+
}
279+
fmt.Fprintf(proxyConn, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", requestAddress, "127.0.0.1")
280+
br := bufio.NewReader(proxyConn)
281+
res, err := http.ReadResponse(br, nil)
282+
if err != nil {
283+
return nil, fmt.Errorf("reading HTTP response from CONNECT to %s via uds proxy %s failed: %v",
284+
requestAddress, o.proxyUdsName, err)
285+
}
286+
if res.StatusCode != 200 {
287+
return nil, fmt.Errorf("proxy error from %s while dialing %s: %v", o.proxyUdsName, requestAddress, res.Status)
288+
}
289+
290+
// It's safe to discard the bufio.Reader here and return the
291+
// original TCP conn directly because we only use this for
292+
// TLS, and in TLS the client speaks first, so we know there's
293+
// no unbuffered data. But we can double-check.
294+
if br.Buffered() > 0 {
295+
return nil, fmt.Errorf("unexpected %d bytes of buffered data from CONNECT uds proxy %q",
296+
br.Buffered(), o.proxyUdsName)
297+
}
298+
default:
299+
return nil, fmt.Errorf("failed to process mode %s", o.mode)
300+
}
301+
302+
return func(ctx context.Context, network, addr string) (net.Conn, error) {
303+
return proxyConn, nil
304+
}, nil
305+
}
306+
307+
func (c *Client) getMTLSDialer(o *GrpcProxyClientOptions) (func(ctx context.Context, network, addr string) (net.Conn, error), error) {
214308
clientCert, err := tls.LoadX509KeyPair(o.clientCert, o.clientKey)
215309
if err != nil {
216310
return nil, fmt.Errorf("failed to read key pair %s & %s, got %v", o.clientCert, o.clientKey, err)

0 commit comments

Comments
 (0)