| 
 | 1 | +// Copyright 2021-2023 The Kubeflow Authors  | 
 | 2 | +//  | 
 | 3 | +// Licensed under the Apache License, Version 2.0 (the "License");  | 
 | 4 | +// you may not use this file except in compliance with the License.  | 
 | 5 | +// You may obtain a copy of the License at  | 
 | 6 | +//  | 
 | 7 | +//	http://www.apache.org/licenses/LICENSE-2.0  | 
 | 8 | +//  | 
 | 9 | +// Unless required by applicable law or agreed to in writing, software  | 
 | 10 | +// distributed under the License is distributed on an "AS IS" BASIS,  | 
 | 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  | 
 | 12 | +// See the License for the specific language governing permissions and  | 
 | 13 | +// limitations under the License.  | 
 | 14 | +package main  | 
 | 15 | + | 
 | 16 | +import (  | 
 | 17 | +	"bytes"  | 
 | 18 | +	"encoding/json"  | 
 | 19 | +	"flag"  | 
 | 20 | +	"fmt"  | 
 | 21 | +	"net/http"  | 
 | 22 | + | 
 | 23 | +	"google.golang.org/protobuf/encoding/protojson"  | 
 | 24 | + | 
 | 25 | +	"github.com/kubeflow/pipelines/backend/src/common/util"  | 
 | 26 | + | 
 | 27 | +	"os"  | 
 | 28 | +	"path/filepath"  | 
 | 29 | +	"strconv"  | 
 | 30 | + | 
 | 31 | +	"github.com/golang/glog"  | 
 | 32 | +	"github.com/kubeflow/pipelines/backend/src/v2/driver"  | 
 | 33 | +	"github.com/kubeflow/pipelines/backend/src/v2/metadata"  | 
 | 34 | +	"github.com/kubeflow/pipelines/kubernetes_platform/go/kubernetesplatform"  | 
 | 35 | +)  | 
 | 36 | + | 
 | 37 | +const (  | 
 | 38 | +	unsetProxyArgValue = "unset"  | 
 | 39 | +	ROOT_DAG           = "ROOT_DAG"  | 
 | 40 | +	DAG                = "DAG"  | 
 | 41 | +	CONTAINER          = "CONTAINER"  | 
 | 42 | +)  | 
 | 43 | + | 
 | 44 | +var (  | 
 | 45 | +	logLevel = flag.String("log_level", "1", "The verbosity level to log.")  | 
 | 46 | + | 
 | 47 | +	// config  | 
 | 48 | +	mlmdServerAddress = flag.String("mlmd_server_address", "metadata-grpc-service", "MLMD server address")  | 
 | 49 | +	mlmdServerPort    = flag.String("mlmd_server_port", "8080", "MLMD server port")  | 
 | 50 | + | 
 | 51 | +	serverPort = flag.String("server_port", ":8080", "Server port")  | 
 | 52 | +)  | 
 | 53 | + | 
 | 54 | +func main() {  | 
 | 55 | +	flag.Parse()  | 
 | 56 | + | 
 | 57 | +	glog.Infof("Setting log level to: '%s'", *logLevel)  | 
 | 58 | +	err := flag.Set("v", *logLevel)  | 
 | 59 | +	if err != nil {  | 
 | 60 | +		glog.Warningf("Failed to set log level: %s", err.Error())  | 
 | 61 | +	}  | 
 | 62 | + | 
 | 63 | +	http.HandleFunc("/api/v1/template.execute", ExecutePlugin)  | 
 | 64 | +	glog.Infof("Server started at http://localhost%v", *serverPort)  | 
 | 65 | +	err = http.ListenAndServe(*serverPort, nil)  | 
 | 66 | +	if err != nil {  | 
 | 67 | +		glog.Warningf("Failed to start http server: %s", err.Error())  | 
 | 68 | +	}  | 
 | 69 | +}  | 
 | 70 | + | 
 | 71 | +// Use WARNING default logging level to facilitate troubleshooting.  | 
 | 72 | +func init() {  | 
 | 73 | +	flag.Set("logtostderr", "true")  | 
 | 74 | +	// Change the WARNING to INFO level for debugging.  | 
 | 75 | +	flag.Set("stderrthreshold", "WARNING")  | 
 | 76 | +}  | 
 | 77 | + | 
 | 78 | +func parseExecConfigJson(k8sExecConfigJson *string) (*kubernetesplatform.KubernetesExecutorConfig, error) {  | 
 | 79 | +	var k8sExecCfg *kubernetesplatform.KubernetesExecutorConfig  | 
 | 80 | +	if *k8sExecConfigJson != "" {  | 
 | 81 | +		glog.Infof("input kubernetesConfig:%s\n", prettyPrint(*k8sExecConfigJson))  | 
 | 82 | +		k8sExecCfg = &kubernetesplatform.KubernetesExecutorConfig{}  | 
 | 83 | +		if err := util.UnmarshalString(*k8sExecConfigJson, k8sExecCfg); err != nil {  | 
 | 84 | +			return nil, fmt.Errorf("failed to unmarshal Kubernetes config, error: %w\nKubernetesConfig: %v", err, k8sExecConfigJson)  | 
 | 85 | +		}  | 
 | 86 | +	}  | 
 | 87 | +	return k8sExecCfg, nil  | 
 | 88 | +}  | 
 | 89 | + | 
 | 90 | +func handleExecution(execution *driver.Execution, driverType string, executionPaths *ExecutionPaths) error {  | 
 | 91 | +	if execution.ID != 0 {  | 
 | 92 | +		glog.Infof("output execution.ID=%v", execution.ID)  | 
 | 93 | +		if executionPaths.ExecutionID != "" {  | 
 | 94 | +			if err := writeFile(executionPaths.ExecutionID, []byte(fmt.Sprint(execution.ID))); err != nil {  | 
 | 95 | +				return fmt.Errorf("failed to write execution ID to file: %w", err)  | 
 | 96 | +			}  | 
 | 97 | +		}  | 
 | 98 | +	}  | 
 | 99 | +	if execution.IterationCount != nil {  | 
 | 100 | +		if err := writeFile(executionPaths.IterationCount, []byte(fmt.Sprintf("%v", *execution.IterationCount))); err != nil {  | 
 | 101 | +			return fmt.Errorf("failed to write iteration count to file: %w", err)  | 
 | 102 | +		}  | 
 | 103 | +	} else {  | 
 | 104 | +		if driverType == ROOT_DAG {  | 
 | 105 | +			if err := writeFile(executionPaths.IterationCount, []byte("0")); err != nil {  | 
 | 106 | +				return fmt.Errorf("failed to write iteration count to file: %w", err)  | 
 | 107 | +			}  | 
 | 108 | +		}  | 
 | 109 | +	}  | 
 | 110 | +	if execution.Cached != nil {  | 
 | 111 | +		if err := writeFile(executionPaths.CachedDecision, []byte(strconv.FormatBool(*execution.Cached))); err != nil {  | 
 | 112 | +			return fmt.Errorf("failed to write cached decision to file: %w", err)  | 
 | 113 | +		}  | 
 | 114 | +	}  | 
 | 115 | +	if execution.Condition != nil {  | 
 | 116 | +		if err := writeFile(executionPaths.Condition, []byte(strconv.FormatBool(*execution.Condition))); err != nil {  | 
 | 117 | +			return fmt.Errorf("failed to write condition to file: %w", err)  | 
 | 118 | +		}  | 
 | 119 | +	} else {  | 
 | 120 | +		// nil is a valid value for Condition  | 
 | 121 | +		if driverType == ROOT_DAG || driverType == CONTAINER {  | 
 | 122 | +			if err := writeFile(executionPaths.Condition, []byte("nil")); err != nil {  | 
 | 123 | +				return fmt.Errorf("failed to write condition to file: %w", err)  | 
 | 124 | +			}  | 
 | 125 | +		}  | 
 | 126 | +	}  | 
 | 127 | +	if execution.PodSpecPatch != "" {  | 
 | 128 | +		glog.Infof("output podSpecPatch=\n%s\n", execution.PodSpecPatch)  | 
 | 129 | +		if executionPaths.PodSpecPatch == "" {  | 
 | 130 | +			return fmt.Errorf("--pod_spec_patch_path is required for container executor drivers")  | 
 | 131 | +		}  | 
 | 132 | +		if err := writeFile(executionPaths.PodSpecPatch, []byte(execution.PodSpecPatch)); err != nil {  | 
 | 133 | +			return fmt.Errorf("failed to write pod spec patch to file: %w", err)  | 
 | 134 | +		}  | 
 | 135 | +	}  | 
 | 136 | +	if execution.ExecutorInput != nil {  | 
 | 137 | +		executorInputBytes, err := protojson.Marshal(execution.ExecutorInput)  | 
 | 138 | +		if err != nil {  | 
 | 139 | +			return fmt.Errorf("failed to marshal ExecutorInput to JSON: %w", err)  | 
 | 140 | +		}  | 
 | 141 | +		executorInputJSON := string(executorInputBytes)  | 
 | 142 | +		glog.Infof("output ExecutorInput:%s\n", prettyPrint(executorInputJSON))  | 
 | 143 | +	}  | 
 | 144 | +	return nil  | 
 | 145 | +}  | 
 | 146 | + | 
 | 147 | +func prettyPrint(jsonStr string) string {  | 
 | 148 | +	var prettyJSON bytes.Buffer  | 
 | 149 | +	err := json.Indent(&prettyJSON, []byte(jsonStr), "", "  ")  | 
 | 150 | +	if err != nil {  | 
 | 151 | +		return jsonStr  | 
 | 152 | +	}  | 
 | 153 | +	return prettyJSON.String()  | 
 | 154 | +}  | 
 | 155 | + | 
 | 156 | +func writeFile(path string, data []byte) (err error) {  | 
 | 157 | +	if path == "" {  | 
 | 158 | +		return fmt.Errorf("path is not specified")  | 
 | 159 | +	}  | 
 | 160 | +	defer func() {  | 
 | 161 | +		if err != nil {  | 
 | 162 | +			err = fmt.Errorf("failed to write to %s: %w", path, err)  | 
 | 163 | +		}  | 
 | 164 | +	}()  | 
 | 165 | +	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {  | 
 | 166 | +		return err  | 
 | 167 | +	}  | 
 | 168 | +	return os.WriteFile(path, data, 0o644)  | 
 | 169 | +}  | 
 | 170 | + | 
 | 171 | +func newMlmdClient() (*metadata.Client, error) {  | 
 | 172 | +	mlmdConfig := metadata.DefaultConfig()  | 
 | 173 | +	if *mlmdServerAddress != "" && *mlmdServerPort != "" {  | 
 | 174 | +		mlmdConfig.Address = *mlmdServerAddress  | 
 | 175 | +		mlmdConfig.Port = *mlmdServerPort  | 
 | 176 | +	}  | 
 | 177 | +	return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port)  | 
 | 178 | +}  | 
0 commit comments