@@ -28,6 +28,7 @@ import (
28
28
"net/http"
29
29
"os"
30
30
"os/signal"
31
+ "time"
31
32
32
33
"github.com/spf13/cobra"
33
34
"github.com/spf13/pflag"
@@ -68,6 +69,7 @@ type GrpcProxyClientOptions struct {
68
69
requestPort int
69
70
proxyHost string
70
71
proxyPort int
72
+ proxyUdsName string
71
73
mode string
72
74
}
73
75
@@ -82,6 +84,7 @@ func (o *GrpcProxyClientOptions) Flags() *pflag.FlagSet {
82
84
flags .IntVar (& o .requestPort , "request-port" , o .requestPort , "The port the request server is listening on." )
83
85
flags .StringVar (& o .proxyHost , "proxy-host" , o .proxyHost , "The host of the proxy server." )
84
86
flags .IntVar (& o .proxyPort , "proxy-port" , o .proxyPort , "The port the proxy server is listening on." )
87
+ flags .StringVar (& o .proxyUdsName , "proxy-uds" , o .proxyHost , "The UDS name to connect to." )
85
88
flags .StringVar (& o .mode , "mode" , o .mode , "Mode can be either 'grpc' or 'http-connect'." )
86
89
87
90
return flags
@@ -97,6 +100,7 @@ func (o *GrpcProxyClientOptions) Print() {
97
100
klog .Warningf ("RequestPort set to %d.\n " , o .requestPort )
98
101
klog .Warningf ("ProxyHost set to %q.\n " , o .proxyHost )
99
102
klog .Warningf ("ProxyPort set to %d.\n " , o .proxyPort )
103
+ klog .Warningf ("ProxyUdsName set to %q.\n " , o .proxyUdsName )
100
104
klog .Warningf ("Mode set to %q.\n " , o .mode )
101
105
}
102
106
@@ -134,9 +138,18 @@ func (o *GrpcProxyClientOptions) Validate() error {
134
138
if o .proxyPort > 49151 {
135
139
return fmt .Errorf ("please do not try to use ephemeral port %d for the proxy server port" , o .proxyPort )
136
140
}
137
- if o .proxyPort < 1024 {
141
+ if o .proxyPort < 1024 && o . proxyUdsName == "" {
138
142
return fmt .Errorf ("please do not try to use reserved port %d for the proxy server port" , o .proxyPort )
139
143
}
144
+ if o .proxyUdsName != "" {
145
+ if o .proxyPort != 0 {
146
+ return fmt .Errorf ("please do set proxy server port to 0 not %d when using UDS" , o .proxyPort )
147
+ }
148
+ if o .clientKey != "" || o .clientCert != "" || o .caCert != "" {
149
+ return fmt .Errorf ("please do set cert materials when using UDS, key = %s, cert = %s, CA = %s" ,
150
+ o .clientKey , o .clientCert , o .caCert )
151
+ }
152
+ }
140
153
return nil
141
154
}
142
155
@@ -151,6 +164,7 @@ func newGrpcProxyClientOptions() *GrpcProxyClientOptions {
151
164
requestPort : 8000 ,
152
165
proxyHost : "localhost" ,
153
166
proxyPort : 8090 ,
167
+ proxyUdsName : "" ,
154
168
mode : "grpc" ,
155
169
}
156
170
return & o
@@ -211,6 +225,86 @@ func (c *Client) run(o *GrpcProxyClientOptions) error {
211
225
}
212
226
213
227
func (c * Client ) getDialer (o * GrpcProxyClientOptions ) (func (ctx context.Context , network , addr string ) (net.Conn , error ), error ) {
228
+ if o .proxyUdsName != "" {
229
+ return c .getUDSDialer (o )
230
+ }
231
+ return c .getMTLSDialer (o )
232
+ }
233
+
234
+ func (c * Client ) getUDSDialer (o * GrpcProxyClientOptions ) (func (ctx context.Context , network , addr string ) (net.Conn , error ), error ) {
235
+ var proxyConn net.Conn
236
+ var err error
237
+
238
+ // Setup signal handler
239
+ ch := make (chan os.Signal , 1 )
240
+ signal .Notify (ch )
241
+
242
+ go func () {
243
+ <- ch
244
+ if proxyConn != nil {
245
+ err := proxyConn .Close ()
246
+ klog .Infof ("connection closed: %v" , err )
247
+ }
248
+ }()
249
+
250
+ switch o .mode {
251
+ case "grpc" :
252
+ dialOption := grpc .WithDialer (func (string , time.Duration ) (net.Conn , error ) {
253
+ // Ignoring addr and timeout arguments:
254
+ // addr - comes from the closure
255
+ // timeout - is turned off as this is test code and eases debugging.
256
+ c , err := net .DialTimeout ("unix" , o .proxyUdsName , 0 )
257
+ if err != nil {
258
+ klog .Errorf ("failed to create connection to uds name %s, error: %v" , o .proxyUdsName , err )
259
+ }
260
+ return c , err
261
+ })
262
+ tunnel , err := client .CreateGrpcTunnel (o .proxyUdsName , dialOption , grpc .WithInsecure ())
263
+ if err != nil {
264
+ return nil , fmt .Errorf ("failed to create tunnel %s, got %v" , o .proxyUdsName , err )
265
+ }
266
+
267
+ requestAddress := fmt .Sprintf ("%s:%d" , o .requestHost , o .requestPort )
268
+ proxyConn , err = tunnel .Dial ("tcp" , requestAddress )
269
+ if err != nil {
270
+ return nil , fmt .Errorf ("failed to dial request %s, got %v" , requestAddress , err )
271
+ }
272
+ case "http-connect" :
273
+ requestAddress := fmt .Sprintf ("%s:%d" , o .requestHost , o .requestPort )
274
+
275
+ proxyConn , err = net .Dial ("unix" , o .proxyUdsName )
276
+ if err != nil {
277
+ return nil , fmt .Errorf ("dialing proxy %q failed: %v" , o .proxyUdsName , err )
278
+ }
279
+ fmt .Fprintf (proxyConn , "CONNECT %s HTTP/1.1\r \n Host: %s\r \n \r \n " , requestAddress , "127.0.0.1" )
280
+ br := bufio .NewReader (proxyConn )
281
+ res , err := http .ReadResponse (br , nil )
282
+ if err != nil {
283
+ return nil , fmt .Errorf ("reading HTTP response from CONNECT to %s via uds proxy %s failed: %v" ,
284
+ requestAddress , o .proxyUdsName , err )
285
+ }
286
+ if res .StatusCode != 200 {
287
+ return nil , fmt .Errorf ("proxy error from %s while dialing %s: %v" , o .proxyUdsName , requestAddress , res .Status )
288
+ }
289
+
290
+ // It's safe to discard the bufio.Reader here and return the
291
+ // original TCP conn directly because we only use this for
292
+ // TLS, and in TLS the client speaks first, so we know there's
293
+ // no unbuffered data. But we can double-check.
294
+ if br .Buffered () > 0 {
295
+ return nil , fmt .Errorf ("unexpected %d bytes of buffered data from CONNECT uds proxy %q" ,
296
+ br .Buffered (), o .proxyUdsName )
297
+ }
298
+ default :
299
+ return nil , fmt .Errorf ("failed to process mode %s" , o .mode )
300
+ }
301
+
302
+ return func (ctx context.Context , network , addr string ) (net.Conn , error ) {
303
+ return proxyConn , nil
304
+ }, nil
305
+ }
306
+
307
+ func (c * Client ) getMTLSDialer (o * GrpcProxyClientOptions ) (func (ctx context.Context , network , addr string ) (net.Conn , error ), error ) {
214
308
clientCert , err := tls .LoadX509KeyPair (o .clientCert , o .clientKey )
215
309
if err != nil {
216
310
return nil , fmt .Errorf ("failed to read key pair %s & %s, got %v" , o .clientCert , o .clientKey , err )
0 commit comments