@@ -16,10 +16,12 @@ package main
16
16
17
17
import (
18
18
"context"
19
+ "crypto/tls"
19
20
"encoding/json"
20
21
"flag"
21
22
"fmt"
22
23
"github.com/kubeflow/pipelines/backend/src/apiserver/client"
24
+ "google.golang.org/grpc/credentials"
23
25
"io"
24
26
"io/ioutil"
25
27
"math"
@@ -52,21 +54,49 @@ var (
52
54
httpPortFlag = flag .String ("httpPortFlag" , ":8888" , "Http Proxy Port" )
53
55
configPath = flag .String ("config" , "" , "Path to JSON file containing config" )
54
56
sampleConfigPath = flag .String ("sampleconfig" , "" , "Path to samples" )
57
+ tlsCertPath = flag .String ("tlsCertPath" , "" , "Path to the public tls cert." )
58
+ tlsCertKeyPath = flag .String ("tlsCertKeyPath" , "" , "Path to the private tls key cert." )
55
59
collectMetricsFlag = flag .Bool ("collectMetricsFlag" , true , "Whether to collect Prometheus metrics in API server." )
56
60
)
57
61
58
62
type RegisterHttpHandlerFromEndpoint func (ctx context.Context , mux * runtime.ServeMux , endpoint string , opts []grpc.DialOption ) error
59
63
64
+ func initCerts () (* tls.Config , error ) {
65
+ if * tlsCertPath == "" && * tlsCertKeyPath == "" {
66
+ // User can choose not to provide certs
67
+ return nil , nil
68
+ } else if * tlsCertPath == "" {
69
+ return nil , fmt .Errorf ("Missing tlsCertPath when specifying cert paths, both tlsCertPath and tlsCertKeyPath are required." )
70
+ } else if * tlsCertKeyPath == "" {
71
+ return nil , fmt .Errorf ("Missing tlsCertKeyPath when specifying cert paths, both tlsCertPath and tlsCertKeyPath are required." )
72
+ }
73
+ serverCert , err := tls .LoadX509KeyPair (* tlsCertPath , * tlsCertKeyPath )
74
+ if err != nil {
75
+ return nil , err
76
+ }
77
+ config := & tls.Config {
78
+ Certificates : []tls.Certificate {serverCert },
79
+ }
80
+ glog .Info ("TLS cert key/pair loaded." )
81
+ return config , err
82
+ }
83
+
60
84
func main () {
61
85
flag .Parse ()
62
86
63
87
initConfig ()
64
88
clientManager := cm .NewClientManager ()
89
+
90
+ tlsConfig , err := initCerts ()
91
+ if err != nil {
92
+ glog .Fatalf ("Failed to parse Cert paths. Err: %v" , err )
93
+ }
94
+
65
95
resourceManager := resource .NewResourceManager (
66
96
& clientManager ,
67
97
& resource.ResourceManagerOptions {CollectMetrics : * collectMetricsFlag },
68
98
)
69
- err : = loadSamples (resourceManager )
99
+ err = loadSamples (resourceManager )
70
100
if err != nil {
71
101
glog .Fatalf ("Failed to load samples. Err: %v" , err )
72
102
}
@@ -78,8 +108,8 @@ func main() {
78
108
}
79
109
}
80
110
81
- go startRpcServer (resourceManager )
82
- startHttpProxy (resourceManager )
111
+ go startRpcServer (resourceManager , tlsConfig )
112
+ startHttpProxy (resourceManager , tlsConfig )
83
113
84
114
clientManager .Close ()
85
115
}
@@ -93,13 +123,25 @@ func grpcCustomMatcher(key string) (string, bool) {
93
123
return strings .ToLower (key ), false
94
124
}
95
125
96
- func startRpcServer (resourceManager * resource.ResourceManager ) {
97
- glog .Info ("Starting RPC server" )
126
+ func startRpcServer (resourceManager * resource.ResourceManager , tlsConfig * tls.Config ) {
127
+ var s * grpc.Server
128
+ if tlsConfig != nil {
129
+ glog .Info ("Starting RPC server (TLS enabled)" )
130
+ tlsCredentials := credentials .NewTLS (tlsConfig )
131
+ s = grpc .NewServer (
132
+ grpc .Creds (tlsCredentials ),
133
+ grpc .UnaryInterceptor (apiServerInterceptor ),
134
+ grpc .MaxRecvMsgSize (math .MaxInt32 ),
135
+ )
136
+ } else {
137
+ glog .Info ("Starting RPC server" )
138
+ s = grpc .NewServer (grpc .UnaryInterceptor (apiServerInterceptor ), grpc .MaxRecvMsgSize (math .MaxInt32 ))
139
+ }
140
+
98
141
listener , err := net .Listen ("tcp" , * rpcPortFlag )
99
142
if err != nil {
100
143
glog .Fatalf ("Failed to start RPC server: %v" , err )
101
144
}
102
- s := grpc .NewServer (grpc .UnaryInterceptor (apiServerInterceptor ), grpc .MaxRecvMsgSize (math .MaxInt32 ))
103
145
104
146
sharedExperimentServer := server .NewExperimentServer (resourceManager , & server.ExperimentServerOptions {CollectMetrics : * collectMetricsFlag })
105
147
sharedPipelineServer := server .NewPipelineServer (
@@ -141,30 +183,29 @@ func startRpcServer(resourceManager *resource.ResourceManager) {
141
183
glog .Info ("RPC server started" )
142
184
}
143
185
144
- func startHttpProxy (resourceManager * resource.ResourceManager ) {
145
- glog .Info ("Starting Http Proxy" )
186
+ func startHttpProxy (resourceManager * resource.ResourceManager , tlsConfig * tls.Config ) {
146
187
147
188
ctx := context .Background ()
148
189
ctx , cancel := context .WithCancel (ctx )
149
190
defer cancel ()
150
191
151
192
// Create gRPC HTTP MUX and register services for v1beta1 api.
152
193
runtimeMux := runtime .NewServeMux (runtime .WithIncomingHeaderMatcher (grpcCustomMatcher ))
153
- registerHttpHandlerFromEndpoint (apiv1beta1 .RegisterPipelineServiceHandlerFromEndpoint , "PipelineService" , ctx , runtimeMux )
154
- registerHttpHandlerFromEndpoint (apiv1beta1 .RegisterExperimentServiceHandlerFromEndpoint , "ExperimentService" , ctx , runtimeMux )
155
- registerHttpHandlerFromEndpoint (apiv1beta1 .RegisterJobServiceHandlerFromEndpoint , "JobService" , ctx , runtimeMux )
156
- registerHttpHandlerFromEndpoint (apiv1beta1 .RegisterRunServiceHandlerFromEndpoint , "RunService" , ctx , runtimeMux )
157
- registerHttpHandlerFromEndpoint (apiv1beta1 .RegisterTaskServiceHandlerFromEndpoint , "TaskService" , ctx , runtimeMux )
158
- registerHttpHandlerFromEndpoint (apiv1beta1 .RegisterReportServiceHandlerFromEndpoint , "ReportService" , ctx , runtimeMux )
159
- registerHttpHandlerFromEndpoint (apiv1beta1 .RegisterVisualizationServiceHandlerFromEndpoint , "Visualization" , ctx , runtimeMux )
160
- registerHttpHandlerFromEndpoint (apiv1beta1 .RegisterAuthServiceHandlerFromEndpoint , "AuthService" , ctx , runtimeMux )
194
+ registerHttpHandlerFromEndpoint (apiv1beta1 .RegisterPipelineServiceHandlerFromEndpoint , "PipelineService" , ctx , runtimeMux , tlsConfig )
195
+ registerHttpHandlerFromEndpoint (apiv1beta1 .RegisterExperimentServiceHandlerFromEndpoint , "ExperimentService" , ctx , runtimeMux , tlsConfig )
196
+ registerHttpHandlerFromEndpoint (apiv1beta1 .RegisterJobServiceHandlerFromEndpoint , "JobService" , ctx , runtimeMux , tlsConfig )
197
+ registerHttpHandlerFromEndpoint (apiv1beta1 .RegisterRunServiceHandlerFromEndpoint , "RunService" , ctx , runtimeMux , tlsConfig )
198
+ registerHttpHandlerFromEndpoint (apiv1beta1 .RegisterTaskServiceHandlerFromEndpoint , "TaskService" , ctx , runtimeMux , tlsConfig )
199
+ registerHttpHandlerFromEndpoint (apiv1beta1 .RegisterReportServiceHandlerFromEndpoint , "ReportService" , ctx , runtimeMux , tlsConfig )
200
+ registerHttpHandlerFromEndpoint (apiv1beta1 .RegisterVisualizationServiceHandlerFromEndpoint , "Visualization" , ctx , runtimeMux , tlsConfig )
201
+ registerHttpHandlerFromEndpoint (apiv1beta1 .RegisterAuthServiceHandlerFromEndpoint , "AuthService" , ctx , runtimeMux , tlsConfig )
161
202
162
203
// Create gRPC HTTP MUX and register services for v2beta1 api.
163
- registerHttpHandlerFromEndpoint (apiv2beta1 .RegisterExperimentServiceHandlerFromEndpoint , "ExperimentService" , ctx , runtimeMux )
164
- registerHttpHandlerFromEndpoint (apiv2beta1 .RegisterPipelineServiceHandlerFromEndpoint , "PipelineService" , ctx , runtimeMux )
165
- registerHttpHandlerFromEndpoint (apiv2beta1 .RegisterRecurringRunServiceHandlerFromEndpoint , "RecurringRunService" , ctx , runtimeMux )
166
- registerHttpHandlerFromEndpoint (apiv2beta1 .RegisterRunServiceHandlerFromEndpoint , "RunService" , ctx , runtimeMux )
167
- registerHttpHandlerFromEndpoint (apiv2beta1 .RegisterArtifactServiceHandlerFromEndpoint , "ArtifactService" , ctx , runtimeMux )
204
+ registerHttpHandlerFromEndpoint (apiv2beta1 .RegisterExperimentServiceHandlerFromEndpoint , "ExperimentService" , ctx , runtimeMux , tlsConfig )
205
+ registerHttpHandlerFromEndpoint (apiv2beta1 .RegisterPipelineServiceHandlerFromEndpoint , "PipelineService" , ctx , runtimeMux , tlsConfig )
206
+ registerHttpHandlerFromEndpoint (apiv2beta1 .RegisterRecurringRunServiceHandlerFromEndpoint , "RecurringRunService" , ctx , runtimeMux , tlsConfig )
207
+ registerHttpHandlerFromEndpoint (apiv2beta1 .RegisterRunServiceHandlerFromEndpoint , "RunService" , ctx , runtimeMux , tlsConfig )
208
+ registerHttpHandlerFromEndpoint (apiv2beta1 .RegisterArtifactServiceHandlerFromEndpoint , "ArtifactService" , ctx , runtimeMux , tlsConfig )
168
209
169
210
// Create a top level mux to include both pipeline upload server and gRPC servers.
170
211
topMux := mux .NewRouter ()
@@ -197,13 +238,35 @@ func startHttpProxy(resourceManager *resource.ResourceManager) {
197
238
// Register a handler for Prometheus to poll.
198
239
topMux .Handle ("/metrics" , promhttp .Handler ())
199
240
200
- http .ListenAndServe (* httpPortFlag , topMux )
241
+ if tlsConfig != nil {
242
+ glog .Info ("Starting Https Proxy" )
243
+ https := http.Server {
244
+ TLSConfig : tlsConfig ,
245
+ Addr : * httpPortFlag ,
246
+ Handler : topMux ,
247
+ }
248
+ https .ListenAndServeTLS ("" , "" )
249
+ } else {
250
+ glog .Info ("Starting Http Proxy" )
251
+ http .ListenAndServe (* httpPortFlag , topMux )
252
+ }
253
+
201
254
glog .Info ("Http Proxy started" )
202
255
}
203
256
204
- func registerHttpHandlerFromEndpoint (handler RegisterHttpHandlerFromEndpoint , serviceName string , ctx context.Context , mux * runtime.ServeMux ) {
257
+ func registerHttpHandlerFromEndpoint (handler RegisterHttpHandlerFromEndpoint , serviceName string , ctx context.Context , mux * runtime.ServeMux , tlsConfig * tls. Config ) {
205
258
endpoint := "localhost" + * rpcPortFlag
206
- opts := []grpc.DialOption {grpc .WithInsecure (), grpc .WithDefaultCallOptions (grpc .MaxCallRecvMsgSize (math .MaxInt32 ))}
259
+ var opts []grpc.DialOption
260
+ if tlsConfig != nil {
261
+ // local client connections via http proxy to grpc should not require tls
262
+ tlsConfig .InsecureSkipVerify = true
263
+ opts = []grpc.DialOption {
264
+ grpc .WithTransportCredentials (credentials .NewTLS (tlsConfig )),
265
+ grpc .WithDefaultCallOptions (grpc .MaxCallRecvMsgSize (math .MaxInt32 )),
266
+ }
267
+ } else {
268
+ opts = []grpc.DialOption {grpc .WithInsecure (), grpc .WithDefaultCallOptions (grpc .MaxCallRecvMsgSize (math .MaxInt32 ))}
269
+ }
207
270
208
271
if err := handler (ctx , mux , endpoint , opts ); err != nil {
209
272
glog .Fatalf ("Failed to register %v handler: %v" , serviceName , err )
0 commit comments