Skip to content

Commit d9466fe

Browse files
author
arpechenin
committed
Central driver POC kubeflow#12023
- Modify Argo compiler: generate a plugin template instead of a container - driver as a http server Signed-off-by: arpechenin <[email protected]>
1 parent cd037e2 commit d9466fe

27 files changed

+5723
-6329
lines changed

backend/src/driver/api/request.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package api
2+
3+
type DriverPluginArgs struct {
4+
CachedDecisionPath string `json:"cached_decision_path"`
5+
Component string `json:"component"`
6+
Container string `json:"container"`
7+
RunMetadata string `json:"run_metadata"`
8+
DagExecutionID string `json:"dag_execution_id"`
9+
IterationIndex string `json:"iteration_index"`
10+
HttpProxy string `json:"http_proxy"`
11+
HttpsProxy string `json:"https_proxy"`
12+
NoProxy string `json:"no_proxy"`
13+
KubernetesConfig string `json:"kubernetes_config"`
14+
RuntimeConfig string `json:"runtime_config"`
15+
PipelineName string `json:"pipeline_name"`
16+
RunID string `json:"run_id"`
17+
RunName string `json:"run_name"`
18+
RunDisplayName string `json:"run_display_name"`
19+
TaskName string `json:"ask_name"`
20+
Task string `json:"task"`
21+
Type string `json:"type"`
22+
CacheDisabledFlag string `json:"cache_disabled"`
23+
PublishLogs string `json:"publish_logs"`
24+
ExecutionIdPath string `json:"execution_id_path"`
25+
IterationCountPath string `json:"iteration_count_path"`
26+
ConditionPath string `json:"condition_path"`
27+
PodSpecPathPath string `json:"pod_spec_patch_path"`
28+
}
29+
30+
type DriverPlugin struct {
31+
DriverPlugin *DriverPluginArgs `json:"driver-plugin"`
32+
}
33+
34+
type DriverTemplate struct {
35+
Plugin *DriverPlugin `json:"plugin"`
36+
}
37+
38+
type DriverRequest struct {
39+
Template *DriverTemplate `json:"template"`
40+
}

backend/src/driver/api/response.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package api
2+
3+
type DriverResponse struct {
4+
Node Node `json:"node"`
5+
}
6+
7+
type Node struct {
8+
Phase string `json:"phase"`
9+
Outputs Outputs `json:"outputs"`
10+
Message string `json:"message"`
11+
}
12+
13+
type Outputs struct {
14+
Parameters []Parameter `json:"parameters"`
15+
}
16+
17+
type Parameter struct {
18+
Name string `json:"name"`
19+
Value string `json:"value"`
20+
}

backend/src/driver/main.go

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
// Copyright 2021-2023 The Kubeflow Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
package main
15+
16+
import (
17+
"bytes"
18+
"encoding/json"
19+
"flag"
20+
"fmt"
21+
"net/http"
22+
23+
"google.golang.org/protobuf/encoding/protojson"
24+
25+
"github.com/kubeflow/pipelines/backend/src/common/util"
26+
27+
"os"
28+
"path/filepath"
29+
"strconv"
30+
31+
"github.com/golang/glog"
32+
"github.com/kubeflow/pipelines/backend/src/v2/driver"
33+
"github.com/kubeflow/pipelines/backend/src/v2/metadata"
34+
"github.com/kubeflow/pipelines/kubernetes_platform/go/kubernetesplatform"
35+
)
36+
37+
const (
38+
driverTypeArg = "type"
39+
httpProxyArg = "http_proxy"
40+
httpsProxyArg = "https_proxy"
41+
noProxyArg = "no_proxy"
42+
unsetProxyArgValue = "unset"
43+
ROOT_DAG = "ROOT_DAG"
44+
DAG = "DAG"
45+
CONTAINER = "CONTAINER"
46+
)
47+
48+
var (
49+
logLevel = flag.String("log_level", "1", "The verbosity level to log.")
50+
51+
// config
52+
mlmdServerAddress = flag.String("mlmd_server_address", "", "MLMD server address")
53+
mlmdServerPort = flag.String("mlmd_server_port", "", "MLMD server port")
54+
)
55+
56+
func main() {
57+
flag.Parse()
58+
59+
glog.Infof("Setting log level to: '%s'", *logLevel)
60+
err := flag.Set("v", *logLevel)
61+
if err != nil {
62+
glog.Warningf("Failed to set log level: %s", err.Error())
63+
}
64+
65+
http.HandleFunc("/api/v1/template.execute", ExecutePlugin)
66+
fmt.Println("Server started at http://localhost:8080")
67+
http.ListenAndServe(":8080", nil)
68+
69+
if err != nil {
70+
glog.Exitf("%v", err)
71+
}
72+
}
73+
74+
// Use WARNING default logging level to facilitate troubleshooting.
75+
func init() {
76+
flag.Set("logtostderr", "true")
77+
// Change the WARNING to INFO level for debugging.
78+
flag.Set("stderrthreshold", "WARNING")
79+
}
80+
81+
func parseExecConfigJson(k8sExecConfigJson *string) (*kubernetesplatform.KubernetesExecutorConfig, error) {
82+
var k8sExecCfg *kubernetesplatform.KubernetesExecutorConfig
83+
if *k8sExecConfigJson != "" {
84+
glog.Infof("input kubernetesConfig:%s\n", prettyPrint(*k8sExecConfigJson))
85+
k8sExecCfg = &kubernetesplatform.KubernetesExecutorConfig{}
86+
if err := util.UnmarshalString(*k8sExecConfigJson, k8sExecCfg); err != nil {
87+
return nil, fmt.Errorf("failed to unmarshal Kubernetes config, error: %w\nKubernetesConfig: %v", err, k8sExecConfigJson)
88+
}
89+
}
90+
return k8sExecCfg, nil
91+
}
92+
93+
func handleExecution(execution *driver.Execution, driverType string, executionPaths *ExecutionPaths) error {
94+
if execution.ID != 0 {
95+
glog.Infof("output execution.ID=%v", execution.ID)
96+
if executionPaths.ExecutionID != "" {
97+
if err := writeFile(executionPaths.ExecutionID, []byte(fmt.Sprint(execution.ID))); err != nil {
98+
return fmt.Errorf("failed to write execution ID to file: %w", err)
99+
}
100+
}
101+
}
102+
if execution.IterationCount != nil {
103+
if err := writeFile(executionPaths.IterationCount, []byte(fmt.Sprintf("%v", *execution.IterationCount))); err != nil {
104+
return fmt.Errorf("failed to write iteration count to file: %w", err)
105+
}
106+
} else {
107+
if driverType == ROOT_DAG {
108+
if err := writeFile(executionPaths.IterationCount, []byte("0")); err != nil {
109+
return fmt.Errorf("failed to write iteration count to file: %w", err)
110+
}
111+
}
112+
}
113+
if execution.Cached != nil {
114+
if err := writeFile(executionPaths.CachedDecision, []byte(strconv.FormatBool(*execution.Cached))); err != nil {
115+
return fmt.Errorf("failed to write cached decision to file: %w", err)
116+
}
117+
}
118+
if execution.Condition != nil {
119+
if err := writeFile(executionPaths.Condition, []byte(strconv.FormatBool(*execution.Condition))); err != nil {
120+
return fmt.Errorf("failed to write condition to file: %w", err)
121+
}
122+
} else {
123+
// nil is a valid value for Condition
124+
if driverType == ROOT_DAG || driverType == CONTAINER {
125+
if err := writeFile(executionPaths.Condition, []byte("nil")); err != nil {
126+
return fmt.Errorf("failed to write condition to file: %w", err)
127+
}
128+
}
129+
}
130+
if execution.PodSpecPatch != "" {
131+
glog.Infof("output podSpecPatch=\n%s\n", execution.PodSpecPatch)
132+
if executionPaths.PodSpecPatch == "" {
133+
return fmt.Errorf("--pod_spec_patch_path is required for container executor drivers")
134+
}
135+
if err := writeFile(executionPaths.PodSpecPatch, []byte(execution.PodSpecPatch)); err != nil {
136+
return fmt.Errorf("failed to write pod spec patch to file: %w", err)
137+
}
138+
}
139+
if execution.ExecutorInput != nil {
140+
executorInputBytes, err := protojson.Marshal(execution.ExecutorInput)
141+
if err != nil {
142+
return fmt.Errorf("failed to marshal ExecutorInput to JSON: %w", err)
143+
}
144+
executorInputJSON := string(executorInputBytes)
145+
glog.Infof("output ExecutorInput:%s\n", prettyPrint(executorInputJSON))
146+
}
147+
return nil
148+
}
149+
150+
func prettyPrint(jsonStr string) string {
151+
var prettyJSON bytes.Buffer
152+
err := json.Indent(&prettyJSON, []byte(jsonStr), "", " ")
153+
if err != nil {
154+
return jsonStr
155+
}
156+
return prettyJSON.String()
157+
}
158+
159+
func writeFile(path string, data []byte) (err error) {
160+
if path == "" {
161+
return fmt.Errorf("path is not specified")
162+
}
163+
defer func() {
164+
if err != nil {
165+
err = fmt.Errorf("failed to write to %s: %w", path, err)
166+
}
167+
}()
168+
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
169+
return err
170+
}
171+
return os.WriteFile(path, data, 0o644)
172+
}
173+
174+
func newMlmdClient() (*metadata.Client, error) {
175+
mlmdConfig := metadata.DefaultConfig()
176+
if *mlmdServerAddress != "" && *mlmdServerPort != "" {
177+
mlmdConfig.Address = *mlmdServerAddress
178+
mlmdConfig.Port = *mlmdServerPort
179+
}
180+
return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port)
181+
}
File renamed without changes.

0 commit comments

Comments
 (0)