Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/Dockerfile.driver
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ RUN GO111MODULE=on go mod download

COPY . .

RUN GO111MODULE=on CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -tags netgo -gcflags="${GCFLAGS}" -ldflags '-extldflags "-static"' -o /bin/driver ./backend/src/v2/cmd/driver/*.go
RUN GO111MODULE=on CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -tags netgo -gcflags="${GCFLAGS}" -ldflags '-extldflags "-static"' -o /bin/driver ./backend/src/driver/*.go

FROM alpine:3.19

Expand Down
44 changes: 44 additions & 0 deletions backend/src/driver/api/request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package api

type DriverPluginArgs struct {
CachedDecisionPath string `json:"cached_decision_path"`
Component string `json:"component,omitempty"`
Container string `json:"container,omitempty"`
RunMetadata string `json:"run_metadata,omitempty"`
DagExecutionID string `json:"dag_execution_id"`
IterationIndex string `json:"iteration_index"`
HttpProxy string `json:"http_proxy"`
HttpsProxy string `json:"https_proxy"`
NoProxy string `json:"no_proxy"`
KubernetesConfig string `json:"kubernetes_config,omitempty"`
RuntimeConfig string `json:"runtime_config,omitempty"`
PipelineName string `json:"pipeline_name"`
RunID string `json:"run_id"`
RunName string `json:"run_name"`
RunDisplayName string `json:"run_display_name"`
TaskName string `json:"task_name"`
Task string `json:"task"`
Type string `json:"type"`
CacheDisabledFlag bool `json:"cache_disabled"`
PublishLogs string `json:"publish_logs"`
ExecutionIdPath string `json:"execution_id_path"`
IterationCountPath string `json:"iteration_count_path"`
ConditionPath string `json:"condition_path"`
PodSpecPathPath string `json:"pod_spec_patch_path"`
}

type DriverPlugin struct {
DriverPlugin *DriverPluginContainer `json:"driver-plugin"`
}

type DriverPluginContainer struct {
Args *DriverPluginArgs `json:"args"`
}

type DriverTemplate struct {
Plugin *DriverPlugin `json:"plugin"`
}

type DriverRequest struct {
Template *DriverTemplate `json:"template"`
}
20 changes: 20 additions & 0 deletions backend/src/driver/api/response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package api

type DriverResponse struct {
Node Node `json:"node"`
}

type Node struct {
Phase string `json:"phase"`
Outputs Outputs `json:"outputs"`
Message string `json:"message"`
}

type Outputs struct {
Parameters []Parameter `json:"parameters"`
}

type Parameter struct {
Name string `json:"name"`
Value string `json:"value"`
}
178 changes: 178 additions & 0 deletions backend/src/driver/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
// Copyright 2021-2023 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main

import (
"bytes"
"encoding/json"
"flag"
"fmt"
"net/http"

"google.golang.org/protobuf/encoding/protojson"

"github.com/kubeflow/pipelines/backend/src/common/util"

"os"
"path/filepath"
"strconv"

"github.com/golang/glog"
"github.com/kubeflow/pipelines/backend/src/v2/driver"
"github.com/kubeflow/pipelines/backend/src/v2/metadata"
"github.com/kubeflow/pipelines/kubernetes_platform/go/kubernetesplatform"
)

const (
unsetProxyArgValue = "unset"
ROOT_DAG = "ROOT_DAG"
DAG = "DAG"
CONTAINER = "CONTAINER"
)

var (
logLevel = flag.String("log_level", "1", "The verbosity level to log.")

// config
mlmdServerAddress = flag.String("mlmd_server_address", "metadata-grpc-service", "MLMD server address")
mlmdServerPort = flag.String("mlmd_server_port", "8080", "MLMD server port")

serverPort = flag.String("server_port", ":8080", "Server port")
)

func main() {
flag.Parse()

glog.Infof("Setting log level to: '%s'", *logLevel)
err := flag.Set("v", *logLevel)
if err != nil {
glog.Warningf("Failed to set log level: %s", err.Error())
}

http.HandleFunc("/api/v1/template.execute", ExecutePlugin)
glog.Infof("Server started at http://localhost%v", *serverPort)
err = http.ListenAndServe(*serverPort, nil)
if err != nil {
glog.Warningf("Failed to start http server: %s", err.Error())
}
}

// Use WARNING default logging level to facilitate troubleshooting.
func init() {
flag.Set("logtostderr", "true")
// Change the WARNING to INFO level for debugging.
flag.Set("stderrthreshold", "WARNING")
}

func parseExecConfigJson(k8sExecConfigJson *string) (*kubernetesplatform.KubernetesExecutorConfig, error) {
var k8sExecCfg *kubernetesplatform.KubernetesExecutorConfig
if *k8sExecConfigJson != "" {
glog.Infof("input kubernetesConfig:%s\n", prettyPrint(*k8sExecConfigJson))
k8sExecCfg = &kubernetesplatform.KubernetesExecutorConfig{}
if err := util.UnmarshalString(*k8sExecConfigJson, k8sExecCfg); err != nil {
return nil, fmt.Errorf("failed to unmarshal Kubernetes config, error: %w\nKubernetesConfig: %v", err, k8sExecConfigJson)
}
}
return k8sExecCfg, nil
}

func handleExecution(execution *driver.Execution, driverType string, executionPaths *ExecutionPaths) error {
if execution.ID != 0 {
glog.Infof("output execution.ID=%v", execution.ID)
if executionPaths.ExecutionID != "" {
if err := writeFile(executionPaths.ExecutionID, []byte(fmt.Sprint(execution.ID))); err != nil {
return fmt.Errorf("failed to write execution ID to file: %w", err)
}
}
}
if execution.IterationCount != nil {
if err := writeFile(executionPaths.IterationCount, []byte(fmt.Sprintf("%v", *execution.IterationCount))); err != nil {
return fmt.Errorf("failed to write iteration count to file: %w", err)
}
} else {
if driverType == ROOT_DAG {
if err := writeFile(executionPaths.IterationCount, []byte("0")); err != nil {
return fmt.Errorf("failed to write iteration count to file: %w", err)
}
}
}
if execution.Cached != nil {
if err := writeFile(executionPaths.CachedDecision, []byte(strconv.FormatBool(*execution.Cached))); err != nil {
return fmt.Errorf("failed to write cached decision to file: %w", err)
}
}
if execution.Condition != nil {
if err := writeFile(executionPaths.Condition, []byte(strconv.FormatBool(*execution.Condition))); err != nil {
return fmt.Errorf("failed to write condition to file: %w", err)
}
} else {
// nil is a valid value for Condition
if driverType == ROOT_DAG || driverType == CONTAINER {
if err := writeFile(executionPaths.Condition, []byte("nil")); err != nil {
return fmt.Errorf("failed to write condition to file: %w", err)
}
}
}
if execution.PodSpecPatch != "" {
glog.Infof("output podSpecPatch=\n%s\n", execution.PodSpecPatch)
if executionPaths.PodSpecPatch == "" {
return fmt.Errorf("--pod_spec_patch_path is required for container executor drivers")
}
if err := writeFile(executionPaths.PodSpecPatch, []byte(execution.PodSpecPatch)); err != nil {
return fmt.Errorf("failed to write pod spec patch to file: %w", err)
}
}
if execution.ExecutorInput != nil {
executorInputBytes, err := protojson.Marshal(execution.ExecutorInput)
if err != nil {
return fmt.Errorf("failed to marshal ExecutorInput to JSON: %w", err)
}
executorInputJSON := string(executorInputBytes)
glog.Infof("output ExecutorInput:%s\n", prettyPrint(executorInputJSON))
}
return nil
}

func prettyPrint(jsonStr string) string {
var prettyJSON bytes.Buffer
err := json.Indent(&prettyJSON, []byte(jsonStr), "", " ")
if err != nil {
return jsonStr
}
return prettyJSON.String()
}

func writeFile(path string, data []byte) (err error) {
if path == "" {
return fmt.Errorf("path is not specified")
}
defer func() {
if err != nil {
err = fmt.Errorf("failed to write to %s: %w", path, err)
}
}()
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
return os.WriteFile(path, data, 0o644)
}

func newMlmdClient() (*metadata.Client, error) {
mlmdConfig := metadata.DefaultConfig()
if *mlmdServerAddress != "" && *mlmdServerPort != "" {
mlmdConfig.Address = *mlmdServerAddress
mlmdConfig.Port = *mlmdServerPort
}
return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port)
}
File renamed without changes.
Loading
Loading