Skip to content

Commit b33a76e

Browse files
committed
UPSTREAM <carry>: Added support for TLS to scheduled-workflow
Signed-off-by: Helber Belmiro <[email protected]>
1 parent 802272a commit b33a76e

File tree

2 files changed

+37
-12
lines changed

2 files changed

+37
-12
lines changed

backend/src/common/util/service.go

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,29 @@ func GetKubernetesClientFromClientConfig(clientConfig clientcmd.ClientConfig) (
100100
return clientSet, config, namespace, nil
101101
}
102102

103-
func GetRpcConnectionWithTimeout(address string, timeout time.Time) (*grpc.ClientConn, error) {
103+
func GetRpcConnectionWithTimeout(address string, tlsEnabled bool, caCertPath string, timeout time.Time) (*grpc.ClientConn, error) {
104+
creds := insecure.NewCredentials()
105+
if tlsEnabled {
106+
if caCertPath == "" {
107+
return nil, errors.New("CA cert path is empty")
108+
}
109+
110+
caCert, err := os.ReadFile(caCertPath)
111+
if err != nil {
112+
return nil, errors.Wrap(err, "Encountered error when reading CA cert path for creating a metadata client.")
113+
}
114+
caCertPool := x509.NewCertPool()
115+
caCertPool.AppendCertsFromPEM(caCert)
116+
117+
config := &tls.Config{
118+
RootCAs: caCertPool,
119+
}
120+
creds = credentials.NewTLS(config)
121+
}
122+
104123
ctx, _ := context.WithDeadline(context.Background(), timeout)
105124

106-
conn, err := grpc.DialContext(ctx, address, grpc.WithInsecure(), grpc.WithBlock())
125+
conn, err := grpc.DialContext(ctx, address, grpc.WithTransportCredentials(creds), grpc.WithBlock())
107126
if err != nil {
108127
return nil, errors.Wrapf(err, "Failed to create gRPC connection")
109128
}

backend/src/crd/controller/scheduledworkflow/main.go

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,26 @@ import (
3636
)
3737

3838
var (
39-
logLevel string
40-
masterURL string
41-
kubeconfig string
42-
namespace string
43-
location *time.Location
44-
clientQPS float64
45-
clientBurst int
46-
mlPipelineAPIServerName string
47-
mlPipelineServiceGRPCPort string
39+
logLevel string
40+
masterURL string
41+
kubeconfig string
42+
namespace string
43+
location *time.Location
44+
clientQPS float64
45+
clientBurst int
46+
mlPipelineAPIServerName string
47+
mlPipelineServiceGRPCPort string
48+
mlPipelineServiceTLSEnabled bool
49+
mlPipelineServiceTLSCert string
4850
)
4951

5052
const (
5153
// These flags match the persistence agent
5254
mlPipelineAPIServerBasePathFlagName = "mlPipelineAPIServerBasePath"
5355
mlPipelineAPIServerNameFlagName = "mlPipelineAPIServerName"
5456
mlPipelineAPIServerGRPCPortFlagName = "mlPipelineServiceGRPCPort"
57+
mlPipelineServiceTLSEnabledFlagName = "mlPipelineServiceTLSEnabled"
58+
mlPipelineServiceTLSCertFlagName = "mlPipelineServiceTLSCert"
5559
apiTokenFile = "/var/run/secrets/kubeflow/tokens/scheduledworkflow-sa-token"
5660
)
5761

@@ -102,7 +106,7 @@ func main() {
102106
grpcAddress := fmt.Sprintf("%s:%s", mlPipelineAPIServerName, mlPipelineServiceGRPCPort)
103107

104108
log.Infof("Connecting the API server over GRPC at: %s", grpcAddress)
105-
apiConnection, err := commonutil.GetRpcConnectionWithTimeout(grpcAddress, time.Now().Add(time.Minute))
109+
apiConnection, err := commonutil.GetRpcConnectionWithTimeout(grpcAddress, mlPipelineServiceTLSEnabled, mlPipelineServiceTLSCert, time.Now().Add(time.Minute))
106110
if err != nil {
107111
log.Fatalf("Error connecting to the API server after trying for one minute: %v", err)
108112
}
@@ -160,6 +164,8 @@ func init() {
160164
flag.Float64Var(&clientQPS, "clientQPS", 5, "The maximum QPS to the master from this client.")
161165
flag.StringVar(&mlPipelineAPIServerName, mlPipelineAPIServerNameFlagName, "ml-pipeline", "Name of the ML pipeline API server.")
162166
flag.StringVar(&mlPipelineServiceGRPCPort, mlPipelineAPIServerGRPCPortFlagName, "8887", "GRPC Port of the ML pipeline API server.")
167+
flag.BoolVar(&mlPipelineServiceTLSEnabled, mlPipelineServiceTLSEnabledFlagName, false, "TLS enabled in the ML pipeline API server.")
168+
flag.StringVar(&mlPipelineServiceTLSCert, mlPipelineServiceTLSCertFlagName, "", "CA cert to connect to the ML pipeline API server.")
163169
flag.IntVar(&clientBurst, "clientBurst", 10, "Maximum burst for throttle from this client.")
164170
var err error
165171
location, err = util.GetLocation()

0 commit comments

Comments
 (0)