Skip to content

Commit 74886ee

Browse files
committed
UPSTREAM: <carry>: add tls support for apiserver http/grpc
Make the mlpipeline server URL scheme configurable, add TLS handling for the persistence agent (PA) and the UI, and remove local gRPC client TLS. Signed-off-by: Humair Khan <[email protected]>
1 parent 9faafad commit 74886ee

File tree

19 files changed

+420
-172
lines changed

19 files changed

+420
-172
lines changed

backend/src/agent/persistence/client/pipeline_client.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,20 @@ func NewPipelineClient(
5555
basePath string,
5656
mlPipelineServiceName string,
5757
mlPipelineServiceHttpPort string,
58-
mlPipelineServiceGRPCPort string) (*PipelineClient, error) {
58+
mlPipelineServiceGRPCPort string,
59+
mlPipelineServiceTLSEnabled bool) (*PipelineClient, error) {
5960
httpAddress := fmt.Sprintf(addressTemp, mlPipelineServiceName, mlPipelineServiceHttpPort)
6061
grpcAddress := fmt.Sprintf(addressTemp, mlPipelineServiceName, mlPipelineServiceGRPCPort)
61-
err := util.WaitForAPIAvailable(initializeTimeout, basePath, httpAddress)
62+
scheme := "http"
63+
if mlPipelineServiceTLSEnabled {
64+
scheme = "https"
65+
}
66+
err := util.WaitForAPIAvailable(initializeTimeout, basePath, httpAddress, scheme)
6267
if err != nil {
6368
return nil, errors.Wrapf(err,
6469
"Failed to initialize pipeline client. Error: %s", err.Error())
6570
}
66-
connection, err := util.GetRpcConnection(grpcAddress)
71+
connection, err := util.GetRpcConnection(grpcAddress, mlPipelineServiceTLSEnabled)
6772
if err != nil {
6873
return nil, errors.Wrapf(err,
6974
"Failed to get RPC connection. Error: %s", err.Error())

backend/src/agent/persistence/main.go

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package main
1616

1717
import (
1818
"flag"
19+
"strconv"
1920
"time"
2021

2122
"github.com/kubeflow/pipelines/backend/src/agent/persistence/client"
@@ -29,21 +30,22 @@ import (
2930
)
3031

3132
var (
32-
masterURL string
33-
kubeconfig string
34-
initializeTimeout time.Duration
35-
timeout time.Duration
36-
mlPipelineAPIServerName string
37-
mlPipelineAPIServerPort string
38-
mlPipelineAPIServerBasePath string
39-
mlPipelineServiceHttpPort string
40-
mlPipelineServiceGRPCPort string
41-
namespace string
42-
ttlSecondsAfterWorkflowFinish int64
43-
numWorker int
44-
clientQPS float64
45-
clientBurst int
46-
saTokenRefreshIntervalInSecs int64
33+
masterURL string
34+
kubeconfig string
35+
initializeTimeout time.Duration
36+
timeout time.Duration
37+
mlPipelineAPIServerName string
38+
mlPipelineAPIServerPort string
39+
mlPipelineAPIServerBasePath string
40+
mlPipelineServiceHttpPort string
41+
mlPipelineServiceGRPCPort string
42+
mlPipelineServiceTLSEnabledStr string
43+
namespace string
44+
ttlSecondsAfterWorkflowFinish int64
45+
numWorker int
46+
clientQPS float64
47+
clientBurst int
48+
saTokenRefreshIntervalInSecs int64
4749
)
4850

4951
const (
@@ -55,6 +57,7 @@ const (
5557
mlPipelineAPIServerNameFlagName = "mlPipelineAPIServerName"
5658
mlPipelineAPIServerHttpPortFlagName = "mlPipelineServiceHttpPort"
5759
mlPipelineAPIServerGRPCPortFlagName = "mlPipelineServiceGRPCPort"
60+
mlPipelineAPIServerTLSEnabled = "mlPipelineServiceTLSEnabled"
5861
namespaceFlagName = "namespace"
5962
ttlSecondsAfterWorkflowFinishFlagName = "ttlSecondsAfterWorkflowFinish"
6063
numWorkerName = "numWorker"
@@ -102,14 +105,20 @@ func main() {
102105
log.Fatalf("Error starting Service Account Token Refresh Ticker due to: %v", err)
103106
}
104107

108+
mlPipelineServiceTLSEnabled, err := strconv.ParseBool(mlPipelineServiceTLSEnabledStr)
109+
if err != nil {
110+
log.Fatalf("Error parsing boolean flag %s, please provide a valid bool value (true/false). %v", mlPipelineAPIServerTLSEnabled, err)
111+
}
112+
105113
pipelineClient, err := client.NewPipelineClient(
106114
initializeTimeout,
107115
timeout,
108116
tokenRefresher,
109117
mlPipelineAPIServerBasePath,
110118
mlPipelineAPIServerName,
111119
mlPipelineServiceHttpPort,
112-
mlPipelineServiceGRPCPort)
120+
mlPipelineServiceGRPCPort,
121+
mlPipelineServiceTLSEnabled)
113122
if err != nil {
114123
log.Fatalf("Error creating ML pipeline API Server client: %v", err)
115124
}
@@ -136,6 +145,7 @@ func init() {
136145
flag.StringVar(&mlPipelineAPIServerName, mlPipelineAPIServerNameFlagName, "ml-pipeline", "Name of the ML pipeline API server.")
137146
flag.StringVar(&mlPipelineServiceHttpPort, mlPipelineAPIServerHttpPortFlagName, "8888", "Http Port of the ML pipeline API server.")
138147
flag.StringVar(&mlPipelineServiceGRPCPort, mlPipelineAPIServerGRPCPortFlagName, "8887", "GRPC Port of the ML pipeline API server.")
148+
flag.StringVar(&mlPipelineServiceTLSEnabledStr, mlPipelineAPIServerTLSEnabled, "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').")
139149
flag.StringVar(&mlPipelineAPIServerBasePath, mlPipelineAPIServerBasePathFlagName,
140150
"/apis/v1beta1", "The base path for the ML pipeline API server.")
141151
flag.StringVar(&namespace, namespaceFlagName, "", "The namespace name used for Kubernetes informers to obtain the listers.")

backend/src/apiserver/main.go

Lines changed: 87 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ package main
1616

1717
import (
1818
"context"
19+
"crypto/tls"
1920
"encoding/json"
2021
"flag"
2122
"fmt"
2223
"github.com/kubeflow/pipelines/backend/src/apiserver/client"
24+
"google.golang.org/grpc/credentials"
2325
"io"
2426
"io/ioutil"
2527
"math"
@@ -52,21 +54,49 @@ var (
5254
httpPortFlag = flag.String("httpPortFlag", ":8888", "Http Proxy Port")
5355
configPath = flag.String("config", "", "Path to JSON file containing config")
5456
sampleConfigPath = flag.String("sampleconfig", "", "Path to samples")
57+
tlsCertPath = flag.String("tlsCertPath", "", "Path to the public tls cert.")
58+
tlsCertKeyPath = flag.String("tlsCertKeyPath", "", "Path to the private tls key cert.")
5559
collectMetricsFlag = flag.Bool("collectMetricsFlag", true, "Whether to collect Prometheus metrics in API server.")
5660
)
5761

5862
// RegisterHttpHandlerFromEndpoint is the signature shared by the generated
// Register*HandlerFromEndpoint functions: given a context, a gateway ServeMux,
// a gRPC endpoint address and the dial options to reach it, the function
// registers its handlers on mux and reports any failure as an error.
type RegisterHttpHandlerFromEndpoint func(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) error
5963

64+
func initCerts() (*tls.Config, error) {
65+
if *tlsCertPath == "" && *tlsCertKeyPath == "" {
66+
// User can choose not to provide certs
67+
return nil, nil
68+
} else if *tlsCertPath == "" {
69+
return nil, fmt.Errorf("Missing tlsCertPath when specifying cert paths, both tlsCertPath and tlsCertKeyPath are required.")
70+
} else if *tlsCertKeyPath == "" {
71+
return nil, fmt.Errorf("Missing tlsCertKeyPath when specifying cert paths, both tlsCertPath and tlsCertKeyPath are required.")
72+
}
73+
serverCert, err := tls.LoadX509KeyPair(*tlsCertPath, *tlsCertKeyPath)
74+
if err != nil {
75+
return nil, err
76+
}
77+
config := &tls.Config{
78+
Certificates: []tls.Certificate{serverCert},
79+
}
80+
glog.Info("TLS cert key/pair loaded.")
81+
return config, err
82+
}
83+
6084
func main() {
6185
flag.Parse()
6286

6387
initConfig()
6488
clientManager := cm.NewClientManager()
89+
90+
tlsConfig, err := initCerts()
91+
if err != nil {
92+
glog.Fatalf("Failed to parse Cert paths. Err: %v", err)
93+
}
94+
6595
resourceManager := resource.NewResourceManager(
6696
&clientManager,
6797
&resource.ResourceManagerOptions{CollectMetrics: *collectMetricsFlag},
6898
)
69-
err := loadSamples(resourceManager)
99+
err = loadSamples(resourceManager)
70100
if err != nil {
71101
glog.Fatalf("Failed to load samples. Err: %v", err)
72102
}
@@ -78,8 +108,8 @@ func main() {
78108
}
79109
}
80110

81-
go startRpcServer(resourceManager)
82-
startHttpProxy(resourceManager)
111+
go startRpcServer(resourceManager, tlsConfig)
112+
startHttpProxy(resourceManager, tlsConfig)
83113

84114
clientManager.Close()
85115
}
@@ -93,13 +123,25 @@ func grpcCustomMatcher(key string) (string, bool) {
93123
return strings.ToLower(key), false
94124
}
95125

96-
func startRpcServer(resourceManager *resource.ResourceManager) {
97-
glog.Info("Starting RPC server")
126+
func startRpcServer(resourceManager *resource.ResourceManager, tlsConfig *tls.Config) {
127+
var s *grpc.Server
128+
if tlsConfig != nil {
129+
glog.Info("Starting RPC server (TLS enabled)")
130+
tlsCredentials := credentials.NewTLS(tlsConfig)
131+
s = grpc.NewServer(
132+
grpc.Creds(tlsCredentials),
133+
grpc.UnaryInterceptor(apiServerInterceptor),
134+
grpc.MaxRecvMsgSize(math.MaxInt32),
135+
)
136+
} else {
137+
glog.Info("Starting RPC server")
138+
s = grpc.NewServer(grpc.UnaryInterceptor(apiServerInterceptor), grpc.MaxRecvMsgSize(math.MaxInt32))
139+
}
140+
98141
listener, err := net.Listen("tcp", *rpcPortFlag)
99142
if err != nil {
100143
glog.Fatalf("Failed to start RPC server: %v", err)
101144
}
102-
s := grpc.NewServer(grpc.UnaryInterceptor(apiServerInterceptor), grpc.MaxRecvMsgSize(math.MaxInt32))
103145

104146
sharedExperimentServer := server.NewExperimentServer(resourceManager, &server.ExperimentServerOptions{CollectMetrics: *collectMetricsFlag})
105147
sharedPipelineServer := server.NewPipelineServer(
@@ -141,30 +183,29 @@ func startRpcServer(resourceManager *resource.ResourceManager) {
141183
glog.Info("RPC server started")
142184
}
143185

144-
func startHttpProxy(resourceManager *resource.ResourceManager) {
145-
glog.Info("Starting Http Proxy")
186+
func startHttpProxy(resourceManager *resource.ResourceManager, tlsConfig *tls.Config) {
146187

147188
ctx := context.Background()
148189
ctx, cancel := context.WithCancel(ctx)
149190
defer cancel()
150191

151192
// Create gRPC HTTP MUX and register services for v1beta1 api.
152193
runtimeMux := runtime.NewServeMux(runtime.WithIncomingHeaderMatcher(grpcCustomMatcher))
153-
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux)
154-
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux)
155-
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterJobServiceHandlerFromEndpoint, "JobService", ctx, runtimeMux)
156-
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux)
157-
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterTaskServiceHandlerFromEndpoint, "TaskService", ctx, runtimeMux)
158-
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterReportServiceHandlerFromEndpoint, "ReportService", ctx, runtimeMux)
159-
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterVisualizationServiceHandlerFromEndpoint, "Visualization", ctx, runtimeMux)
160-
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterAuthServiceHandlerFromEndpoint, "AuthService", ctx, runtimeMux)
194+
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux, tlsConfig)
195+
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux, tlsConfig)
196+
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterJobServiceHandlerFromEndpoint, "JobService", ctx, runtimeMux, tlsConfig)
197+
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux, tlsConfig)
198+
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterTaskServiceHandlerFromEndpoint, "TaskService", ctx, runtimeMux, tlsConfig)
199+
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterReportServiceHandlerFromEndpoint, "ReportService", ctx, runtimeMux, tlsConfig)
200+
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterVisualizationServiceHandlerFromEndpoint, "Visualization", ctx, runtimeMux, tlsConfig)
201+
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterAuthServiceHandlerFromEndpoint, "AuthService", ctx, runtimeMux, tlsConfig)
161202

162203
// Create gRPC HTTP MUX and register services for v2beta1 api.
163-
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux)
164-
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux)
165-
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRecurringRunServiceHandlerFromEndpoint, "RecurringRunService", ctx, runtimeMux)
166-
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux)
167-
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterArtifactServiceHandlerFromEndpoint, "ArtifactService", ctx, runtimeMux)
204+
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux, tlsConfig)
205+
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux, tlsConfig)
206+
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRecurringRunServiceHandlerFromEndpoint, "RecurringRunService", ctx, runtimeMux, tlsConfig)
207+
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux, tlsConfig)
208+
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterArtifactServiceHandlerFromEndpoint, "ArtifactService", ctx, runtimeMux, tlsConfig)
168209

169210
// Create a top level mux to include both pipeline upload server and gRPC servers.
170211
topMux := mux.NewRouter()
@@ -197,13 +238,35 @@ func startHttpProxy(resourceManager *resource.ResourceManager) {
197238
// Register a handler for Prometheus to poll.
198239
topMux.Handle("/metrics", promhttp.Handler())
199240

200-
http.ListenAndServe(*httpPortFlag, topMux)
241+
if tlsConfig != nil {
242+
glog.Info("Starting Https Proxy")
243+
https := http.Server{
244+
TLSConfig: tlsConfig,
245+
Addr: *httpPortFlag,
246+
Handler: topMux,
247+
}
248+
https.ListenAndServeTLS("", "")
249+
} else {
250+
glog.Info("Starting Http Proxy")
251+
http.ListenAndServe(*httpPortFlag, topMux)
252+
}
253+
201254
glog.Info("Http Proxy started")
202255
}
203256

204-
func registerHttpHandlerFromEndpoint(handler RegisterHttpHandlerFromEndpoint, serviceName string, ctx context.Context, mux *runtime.ServeMux) {
257+
func registerHttpHandlerFromEndpoint(handler RegisterHttpHandlerFromEndpoint, serviceName string, ctx context.Context, mux *runtime.ServeMux, tlsConfig *tls.Config) {
205258
endpoint := "localhost" + *rpcPortFlag
206-
opts := []grpc.DialOption{grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32))}
259+
var opts []grpc.DialOption
260+
if tlsConfig != nil {
261+
// local client connections via http proxy to grpc should not require tls
262+
tlsConfig.InsecureSkipVerify = true
263+
opts = []grpc.DialOption{
264+
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
265+
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)),
266+
}
267+
} else {
268+
opts = []grpc.DialOption{grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32))}
269+
}
207270

208271
if err := handler(ctx, mux, endpoint, opts); err != nil {
209272
glog.Fatalf("Failed to register %v handler: %v", serviceName, err)

backend/src/common/util/service.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
package util
1616

1717
import (
18+
"crypto/tls"
1819
"fmt"
20+
"google.golang.org/grpc/credentials"
21+
"google.golang.org/grpc/credentials/insecure"
1922
"net/http"
2023
"strings"
2124
"time"
@@ -28,9 +31,9 @@ import (
2831
"k8s.io/client-go/tools/clientcmd"
2932
)
3033

31-
func WaitForAPIAvailable(initializeTimeout time.Duration, basePath string, apiAddress string) error {
34+
func WaitForAPIAvailable(initializeTimeout time.Duration, basePath string, apiAddress string, scheme string) error {
3235
operation := func() error {
33-
response, err := http.Get(fmt.Sprintf("http://%s%s/healthz", apiAddress, basePath))
36+
response, err := http.Get(fmt.Sprintf("%s://%s%s/healthz", scheme, apiAddress, basePath))
3437
if err != nil {
3538
return err
3639
}
@@ -74,8 +77,17 @@ func GetKubernetesClientFromClientConfig(clientConfig clientcmd.ClientConfig) (
7477
return clientSet, config, namespace, nil
7578
}
7679

77-
func GetRpcConnection(address string) (*grpc.ClientConn, error) {
78-
conn, err := grpc.Dial(address, grpc.WithInsecure())
80+
func GetRpcConnection(address string, tlsEnabled bool) (*grpc.ClientConn, error) {
81+
creds := insecure.NewCredentials()
82+
if tlsEnabled {
83+
config := &tls.Config{}
84+
creds = credentials.NewTLS(config)
85+
}
86+
87+
conn, err := grpc.Dial(
88+
address,
89+
grpc.WithTransportCredentials(creds),
90+
)
7991
if err != nil {
8092
return nil, errors.Wrapf(err, "Failed to create gRPC connection")
8193
}

backend/src/v2/cacheutils/cache.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@ package cacheutils
33
import (
44
"context"
55
"crypto/sha256"
6+
"crypto/tls"
67
"encoding/hex"
78
"encoding/json"
89
"fmt"
10+
"google.golang.org/grpc/credentials"
11+
"google.golang.org/grpc/credentials/insecure"
912
"os"
1013

1114
"google.golang.org/grpc"
@@ -111,10 +114,21 @@ type Client struct {
111114
}
112115

113116
// NewClient creates a Client.
114-
func NewClient() (*Client, error) {
117+
func NewClient(mlPipelineServiceTLSEnabled bool) (*Client, error) {
118+
creds := insecure.NewCredentials()
119+
if mlPipelineServiceTLSEnabled {
120+
config := &tls.Config{
121+
InsecureSkipVerify: false,
122+
}
123+
creds = credentials.NewTLS(config)
124+
}
115125
cacheEndPoint := cacheDefaultEndpoint()
116126
glog.Infof("Connecting to cache endpoint %s", cacheEndPoint)
117-
conn, err := grpc.Dial(cacheEndPoint, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxClientGRPCMessageSize)), grpc.WithInsecure())
127+
conn, err := grpc.Dial(
128+
cacheEndPoint,
129+
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxClientGRPCMessageSize)),
130+
grpc.WithTransportCredentials(creds),
131+
)
118132
if err != nil {
119133
return nil, fmt.Errorf("metadata.NewClient() failed: %w", err)
120134
}

0 commit comments

Comments
 (0)