Skip to content

Commit bdabd17

Browse files
authored
Merge pull request #103 from hbelmiro/RHOAIENG-13871-stable
UPSTREAM: <carry>: Use DSPA custom ca cert on MLMD and Persistence Agent clients
2 parents 67ff1b5 + 7e7a993 commit bdabd17

File tree

21 files changed

+101
-23
lines changed

21 files changed

+101
-23
lines changed

backend/src/agent/persistence/client/pipeline_client.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ func NewPipelineClient(
5656
mlPipelineServiceName string,
5757
mlPipelineServiceHttpPort string,
5858
mlPipelineServiceGRPCPort string,
59-
mlPipelineServiceTLSEnabled bool) (*PipelineClient, error) {
59+
mlPipelineServiceTLSEnabled bool,
60+
caCertPath string) (*PipelineClient, error) {
6061
httpAddress := fmt.Sprintf(addressTemp, mlPipelineServiceName, mlPipelineServiceHttpPort)
6162
grpcAddress := fmt.Sprintf(addressTemp, mlPipelineServiceName, mlPipelineServiceGRPCPort)
6263
scheme := "http"
@@ -68,7 +69,7 @@ func NewPipelineClient(
6869
return nil, errors.Wrapf(err,
6970
"Failed to initialize pipeline client. Error: %s", err.Error())
7071
}
71-
connection, err := util.GetRpcConnection(grpcAddress, mlPipelineServiceTLSEnabled)
72+
connection, err := util.GetRpcConnection(grpcAddress, mlPipelineServiceTLSEnabled, caCertPath)
7273
if err != nil {
7374
return nil, errors.Wrapf(err,
7475
"Failed to get RPC connection. Error: %s", err.Error())

backend/src/agent/persistence/main.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ var (
4848
clientBurst int
4949
executionType string
5050
saTokenRefreshIntervalInSecs int64
51+
caCertPath string
5152
)
5253

5354
const (
@@ -68,6 +69,7 @@ const (
6869
clientBurstFlagName = "clientBurst"
6970
executionTypeFlagName = "executionType"
7071
saTokenRefreshIntervalFlagName = "saTokenRefreshIntervalInSecs"
72+
caCertPathFlagName = "caCertPath"
7173
)
7274

7375
const (
@@ -135,7 +137,8 @@ func main() {
135137
mlPipelineAPIServerName,
136138
mlPipelineServiceHttpPort,
137139
mlPipelineServiceGRPCPort,
138-
mlPipelineServiceTLSEnabled)
140+
mlPipelineServiceTLSEnabled,
141+
caCertPath)
139142
if err != nil {
140143
log.Fatalf("Error creating ML pipeline API Server client: %v", err)
141144
}
@@ -177,5 +180,5 @@ func init() {
177180
// TODO use viper/config file instead. Sync `saTokenRefreshIntervalFlagName` with the value from manifest file by using ENV var.
178181
flag.Int64Var(&saTokenRefreshIntervalInSecs, saTokenRefreshIntervalFlagName, DefaultSATokenRefresherIntervalInSecs, "Persistence agent service account token read interval in seconds. "+
179182
"Defines how often `/var/run/secrets/kubeflow/tokens/kubeflow-persistent_agent-api-token` to be read")
180-
183+
flag.StringVar(&caCertPath, caCertPathFlagName, "", "The path to the CA certificate.")
181184
}

backend/src/apiserver/client_manager/client_manager.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ func (c *ClientManager) init() {
208208

209209
c.k8sCoreClient = client.CreateKubernetesCoreOrFatal(common.GetDurationConfig(initConnectionTimeout), clientParams)
210210

211-
newClient, err := metadata.NewClient(common.GetMetadataGrpcServiceServiceHost(), common.GetMetadataGrpcServiceServicePort(), common.GetMetadataTLSEnabled())
211+
newClient, err := metadata.NewClient(common.GetMetadataGrpcServiceServiceHost(), common.GetMetadataGrpcServiceServicePort(), common.GetMetadataTLSEnabled(), common.GetCaCertPath())
212212

213213
if err != nil {
214214
glog.Fatalf("Failed to create metadata client. Error: %v", err)

backend/src/apiserver/common/config.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ const (
3636
MetadataGrpcServiceServicePort string = "METADATA_GRPC_SERVICE_SERVICE_PORT"
3737
MetadataTLSEnabled string = "METADATA_TLS_ENABLED"
3838
SignedURLExpiryTimeSeconds string = "SIGNED_URL_EXPIRY_TIME_SECONDS"
39+
CaBundleMountPath string = "ARTIFACT_COPY_STEP_CABUNDLE_MOUNTPATH"
40+
CaBundleConfigMapKey string = "ARTIFACT_COPY_STEP_CABUNDLE_CONFIGMAP_KEY"
41+
CaBundleConfigMapName string = "ARTIFACT_COPY_STEP_CABUNDLE_CONFIGMAP_NAME"
3942
)
4043

4144
func IsPipelineVersionUpdatedByDefault() bool {
@@ -147,3 +150,13 @@ func GetSignedURLExpiryTimeSeconds() int {
147150
func GetMetadataTLSEnabled() bool {
148151
return GetBoolConfigWithDefault(MetadataTLSEnabled, DefaultMetadataTLSEnabled)
149152
}
153+
154+
func GetCaCertPath() string {
155+
caBundleMountPath := GetStringConfigWithDefault(CaBundleMountPath, "")
156+
if caBundleMountPath != "" {
157+
caBundleConfigMapKey := GetStringConfigWithDefault(CaBundleConfigMapKey, "")
158+
return caBundleMountPath + "/" + caBundleConfigMapKey
159+
} else {
160+
return ""
161+
}
162+
}

backend/src/common/util/service.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ package util
1616

1717
import (
1818
"crypto/tls"
19+
"crypto/x509"
1920
"fmt"
2021
"google.golang.org/grpc/credentials"
2122
"google.golang.org/grpc/credentials/insecure"
2223
"net/http"
24+
"os"
2325
"strings"
2426
"time"
2527

@@ -77,10 +79,23 @@ func GetKubernetesClientFromClientConfig(clientConfig clientcmd.ClientConfig) (
7779
return clientSet, config, namespace, nil
7880
}
7981

80-
func GetRpcConnection(address string, tlsEnabled bool) (*grpc.ClientConn, error) {
82+
func GetRpcConnection(address string, tlsEnabled bool, caCertPath string) (*grpc.ClientConn, error) {
8183
creds := insecure.NewCredentials()
8284
if tlsEnabled {
83-
config := &tls.Config{}
85+
if caCertPath == "" {
86+
return nil, errors.New("CA cert path is empty")
87+
}
88+
89+
caCert, err := os.ReadFile(caCertPath)
90+
if err != nil {
91+
return nil, errors.Wrap(err, "Encountered error when reading CA cert path for creating a metadata client.")
92+
}
93+
caCertPool := x509.NewCertPool()
94+
caCertPool.AppendCertsFromPEM(caCert)
95+
96+
config := &tls.Config{
97+
RootCAs: caCertPool,
98+
}
8499
creds = credentials.NewTLS(config)
85100
}
86101

backend/src/v2/cmd/driver/main.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ var (
7171

7272
mlPipelineServiceTLSEnabledStr = flag.String("mlPipelineServiceTLSEnabled", "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').")
7373
metadataTLSEnabledStr = flag.String("metadataTLSEnabled", "false", "Set to 'true' if metadata server serves over TLS (default: 'false').")
74+
caCertPath = flag.String("ca_cert_path", "", "The path to the CA certificate.")
7475
)
7576

7677
// func RootDAG(pipelineName string, runID string, component *pipelinespec.ComponentSpec, task *pipelinespec.PipelineTaskSpec, mlmd *metadata.Client) (*Execution, error) {
@@ -176,6 +177,7 @@ func drive() (err error) {
176177
MLMDServerAddress: *mlmdServerAddress,
177178
MLMDServerPort: *mlmdServerPort,
178179
MLMDTLSEnabled: metadataTLSEnabled,
180+
CaCertPath: *caCertPath,
179181
}
180182
var execution *driver.Execution
181183
var driverErr error
@@ -307,5 +309,5 @@ func newMlmdClient() (*metadata.Client, error) {
307309
return nil, err
308310
}
309311

310-
return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port, tlsEnabled)
312+
return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port, tlsEnabled, *caCertPath)
311313
}

backend/src/v2/cmd/launcher-v2/main.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ var (
4444
mlmdServerPort = flag.String("mlmd_server_port", "8080", "The MLMD gRPC server port.")
4545
mlPipelineServiceTLSEnabledStr = flag.String("mlPipelineServiceTLSEnabled", "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').")
4646
metadataTLSEnabledStr = flag.String("metadataTLSEnabled", "false", "Set to 'true' if metadata server serves over TLS (default: 'false').")
47+
caCertPath = flag.String("ca_cert_path", "", "The path to the CA certificate.")
4748
)
4849

4950
func main() {
@@ -88,6 +89,7 @@ func run() error {
8889
RunID: *runID,
8990
MLPipelineTLSEnabled: mlPipelineServiceTLSEnabled,
9091
MetadataTLSEnabled: metadataServiceTLSEnabled,
92+
CaCertPath: *caCertPath,
9193
}
9294

9395
switch *executorType {

backend/src/v2/compiler/argocompiler/common.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package argocompiler
1717
import (
1818
"fmt"
1919
wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
20+
"github.com/kubeflow/pipelines/backend/src/apiserver/common"
2021
k8score "k8s.io/api/core/v1"
2122
"os"
2223
"strconv"
@@ -98,9 +99,9 @@ func GetMLPipelineServicePortGRPC() string {
9899
// ConfigureCABundle adds CABundle environment variables and volume mounts
99100
// if CA Bundle env vars are specified.
100101
func ConfigureCABundle(tmpl *wfapi.Template) {
101-
caBundleCfgMapName := os.Getenv("ARTIFACT_COPY_STEP_CABUNDLE_CONFIGMAP_NAME")
102-
caBundleCfgMapKey := os.Getenv("ARTIFACT_COPY_STEP_CABUNDLE_CONFIGMAP_KEY")
103-
caBundleMountPath := os.Getenv("ARTIFACT_COPY_STEP_CABUNDLE_MOUNTPATH")
102+
caBundleCfgMapName := os.Getenv(common.CaBundleConfigMapName)
103+
caBundleCfgMapKey := os.Getenv(common.CaBundleConfigMapKey)
104+
caBundleMountPath := os.Getenv(common.CaBundleMountPath)
104105
if caBundleCfgMapName != "" && caBundleCfgMapKey != "" {
105106
caFile := fmt.Sprintf("%s/%s", caBundleMountPath, caBundleCfgMapKey)
106107
var certDirectories = []string{

backend/src/v2/compiler/argocompiler/container.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ func (c *workflowCompiler) addContainerDriverTemplate() string {
171171
"--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(),
172172
"--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(),
173173
"--metadataTLSEnabled", strconv.FormatBool(common.GetMetadataTLSEnabled()),
174+
"--ca_cert_path", common.GetCaCertPath(),
174175
},
175176
Resources: driverResources,
176177
},

backend/src/v2/compiler/argocompiler/dag.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,7 @@ func (c *workflowCompiler) addDAGDriverTemplate() string {
447447
"--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(),
448448
"--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(),
449449
"--metadataTLSEnabled", strconv.FormatBool(common.GetMetadataTLSEnabled()),
450+
"--ca_cert_path", common.GetCaCertPath(),
450451
},
451452
Resources: driverResources,
452453
},

0 commit comments

Comments
 (0)