@@ -45,8 +45,6 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error
45
45
grpc .WithDefaultCallOptions (grpc .MaxCallSendMsgSize (defaults .DefaultMaxSendMsgSize )),
46
46
}
47
47
needDialer := true
48
- needWithInsecure := true
49
- tlsServerName := ""
50
48
51
49
var unary []grpc.UnaryClientInterceptor
52
50
var stream []grpc.StreamClientInterceptor
@@ -56,19 +54,17 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error
56
54
var tracerDelegate TracerDelegate
57
55
var sessionDialer func (context.Context , string , map [string ][]string ) (net.Conn , error )
58
56
var customDialOptions []grpc.DialOption
57
+ var creds * withCredentials
59
58
60
59
for _ , o := range opts {
61
60
if _ , ok := o .(* withFailFast ); ok {
62
61
gopts = append (gopts , grpc .FailOnNonTempDialError (true ))
63
62
}
64
63
if credInfo , ok := o .(* withCredentials ); ok {
65
- opt , err := loadCredentials (credInfo )
66
- if err != nil {
67
- return nil , err
64
+ if creds == nil {
65
+ creds = & withCredentials {}
68
66
}
69
- gopts = append (gopts , opt )
70
- needWithInsecure = false
71
- tlsServerName = credInfo .ServerName
67
+ creds = creds .merge (credInfo )
72
68
}
73
69
if wt , ok := o .(* withTracer ); ok {
74
70
customTracer = true
@@ -89,6 +85,16 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error
89
85
}
90
86
}
91
87
88
+ if creds == nil {
89
+ gopts = append (gopts , grpc .WithTransportCredentials (insecure .NewCredentials ()))
90
+ } else {
91
+ credOpts , err := loadCredentials (creds )
92
+ if err != nil {
93
+ return nil , err
94
+ }
95
+ gopts = append (gopts , credOpts )
96
+ }
97
+
92
98
if ! customTracer {
93
99
if span := trace .SpanFromContext (ctx ); span .SpanContext ().IsValid () {
94
100
tracerProvider = span .TracerProvider ()
@@ -108,9 +114,6 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error
108
114
}
109
115
gopts = append (gopts , grpc .WithContextDialer (dialFn ))
110
116
}
111
- if needWithInsecure {
112
- gopts = append (gopts , grpc .WithTransportCredentials (insecure .NewCredentials ()))
113
- }
114
117
if address == "" {
115
118
address = appdefaults .Address
116
119
}
@@ -122,7 +125,10 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error
122
125
// ref: https://datatracker.ietf.org/doc/html/rfc7540#section-8.1.2.3
123
126
// - However, when TLS specified, grpc-go requires it must match
124
127
// with its servername specified for certificate validation.
125
- authority := tlsServerName
128
+ var authority string
129
+ if creds != nil && creds .serverName != "" {
130
+ authority = creds .serverName
131
+ }
126
132
if authority == "" {
127
133
// authority as hostname from target address
128
134
uri , err := url .Parse (address )
@@ -201,47 +207,108 @@ func WithContextDialer(df func(context.Context, string) (net.Conn, error)) Clien
201
207
}
202
208
203
209
type withCredentials struct {
204
- ServerName string
205
- CACert string
206
- Cert string
207
- Key string
210
+ // server options
211
+ serverName string
212
+ caCert string
213
+ caCertSystem bool
214
+
215
+ // client options
216
+ cert string
217
+ key string
218
+ }
219
+
220
+ func (opts * withCredentials ) merge (opts2 * withCredentials ) * withCredentials {
221
+ result := * opts
222
+ if opts2 == nil {
223
+ return & result
224
+ }
225
+
226
+ // server options
227
+ if opts2 .serverName != "" {
228
+ result .serverName = opts2 .serverName
229
+ }
230
+ if opts2 .caCert != "" {
231
+ result .caCert = opts2 .caCert
232
+ }
233
+ if opts2 .caCertSystem {
234
+ result .caCertSystem = opts2 .caCertSystem
235
+ }
236
+
237
+ // client options
238
+ if opts2 .cert != "" {
239
+ result .cert = opts2 .cert
240
+ }
241
+ if opts2 .key != "" {
242
+ result .key = opts2 .key
243
+ }
244
+
245
+ return & result
208
246
}
209
247
210
248
func (* withCredentials ) isClientOpt () {}
211
249
212
250
// WithCredentials configures the TLS parameters of the client.
213
251
// Arguments:
214
- // * serverName: specifies the name of the target server
215
- // * ca: specifies the filepath of the CA certificate to use for verification
216
- // * cert: specifies the filepath of the client certificate
217
- // * key: specifies the filepath of the client key
218
- func WithCredentials (serverName , ca , cert , key string ) ClientOpt {
219
- return & withCredentials {serverName , ca , cert , key }
252
+ // * cert: specifies the filepath of the client certificate
253
+ // * key: specifies the filepath of the client key
254
+ func WithCredentials (cert , key string ) ClientOpt {
255
+ return & withCredentials {
256
+ cert : cert ,
257
+ key : key ,
258
+ }
259
+ }
260
+
261
+ // WithServerConfig configures the TLS parameters to connect to the server.
262
+ // Arguments:
263
+ // * serverName: specifies the server name to verify the hostname
264
+ // * caCert: specifies the filepath of the CA certificate
265
+ func WithServerConfig (serverName , caCert string ) ClientOpt {
266
+ return & withCredentials {
267
+ serverName : serverName ,
268
+ caCert : caCert ,
269
+ }
270
+ }
271
+
272
+ // WithServerConfigSystem configures the TLS parameters to connect to the
273
+ // server, using the system's certificate pool.
274
+ func WithServerConfigSystem (serverName string ) ClientOpt {
275
+ return & withCredentials {
276
+ serverName : serverName ,
277
+ caCertSystem : true ,
278
+ }
220
279
}
221
280
222
281
func loadCredentials (opts * withCredentials ) (grpc.DialOption , error ) {
223
- ca , err := os .ReadFile (opts .CACert )
224
- if err != nil {
225
- return nil , errors .Wrap (err , "could not read ca certificate" )
282
+ cfg := & tls.Config {}
283
+
284
+ if opts .caCertSystem {
285
+ cfg .RootCAs , _ = x509 .SystemCertPool ()
286
+ }
287
+ if cfg .RootCAs == nil {
288
+ cfg .RootCAs = x509 .NewCertPool ()
226
289
}
227
290
228
- certPool := x509 .NewCertPool ()
229
- if ok := certPool .AppendCertsFromPEM (ca ); ! ok {
230
- return nil , errors .New ("failed to append ca certs" )
291
+ if opts .caCert != "" {
292
+ ca , err := os .ReadFile (opts .caCert )
293
+ if err != nil {
294
+ return nil , errors .Wrap (err , "could not read ca certificate" )
295
+ }
296
+ if ok := cfg .RootCAs .AppendCertsFromPEM (ca ); ! ok {
297
+ return nil , errors .New ("failed to append ca certs" )
298
+ }
231
299
}
232
300
233
- cfg := & tls.Config {
234
- ServerName : opts .ServerName ,
235
- RootCAs : certPool ,
301
+ if opts .serverName != "" {
302
+ cfg .ServerName = opts .serverName
236
303
}
237
304
238
305
// we will produce an error if the user forgot about either cert or key if at least one is specified
239
- if opts .Cert != "" || opts .Key != "" {
240
- cert , err := tls .LoadX509KeyPair (opts .Cert , opts .Key )
306
+ if opts .cert != "" || opts .key != "" {
307
+ cert , err := tls .LoadX509KeyPair (opts .cert , opts .key )
241
308
if err != nil {
242
309
return nil , errors .Wrap (err , "could not read certificate/key" )
243
310
}
244
- cfg .Certificates = []tls. Certificate { cert }
311
+ cfg .Certificates = append ( cfg . Certificates , cert )
245
312
}
246
313
247
314
return grpc .WithTransportCredentials (credentials .NewTLS (cfg )), nil
0 commit comments