Skip to content

Commit 3388fc7

Browse files
author
arpechenin
committed
Central driver POC #12023
- Modify Argo compiler: generate a plugin template instead of a container Signed-off-by: arpechenin <[email protected]>
1 parent cd037e2 commit 3388fc7

20 files changed

+5192
-5976
lines changed

backend/src/v2/compiler/argocompiler/container.go

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,15 @@ func GetPipelineRunAsUser() *int64 {
145145
return &runAsUser
146146
}
147147

148-
func (c *workflowCompiler) containerDriverTask(name string, inputs containerDriverInputs) (*wfapi.DAGTask, *containerDriverOutputs) {
148+
func (c *workflowCompiler) containerDriverTask(name string, inputs containerDriverInputs) (*wfapi.DAGTask, *containerDriverOutputs, error) {
149+
template, err := c.addContainerDriverTemplate()
150+
if err != nil {
151+
return nil, nil, err
152+
}
153+
149154
dagTask := &wfapi.DAGTask{
150155
Name: name,
151-
Template: c.addContainerDriverTemplate(),
156+
Template: template,
152157
Arguments: wfapi.Arguments{
153158
Parameters: []wfapi.Parameter{
154159
{Name: paramComponent, Value: wfapi.AnyStringPtr(inputs.component)},
@@ -176,44 +181,44 @@ func (c *workflowCompiler) containerDriverTask(name string, inputs containerDriv
176181
cached: taskOutputParameter(name, paramCachedDecision),
177182
condition: taskOutputParameter(name, paramCondition),
178183
}
179-
return dagTask, outputs
184+
return dagTask, outputs, nil
180185
}
181186

182-
func (c *workflowCompiler) addContainerDriverTemplate() string {
187+
func (c *workflowCompiler) addContainerDriverTemplate() (string, error) {
183188
name := "system-container-driver"
184189
_, ok := c.templates[name]
185190
if ok {
186-
return name
187-
}
188-
189-
args := []string{
190-
"--type", "CONTAINER",
191-
"--pipeline_name", c.spec.GetPipelineInfo().GetName(),
192-
"--run_id", runID(),
193-
"--run_name", runResourceName(),
194-
"--run_display_name", c.job.DisplayName,
195-
"--dag_execution_id", inputValue(paramParentDagID),
196-
"--component", inputValue(paramComponent),
197-
"--task", inputValue(paramTask),
198-
"--task_name", inputValue(paramTaskName),
199-
"--container", inputValue(paramContainer),
200-
"--iteration_index", inputValue(paramIterationIndex),
201-
"--cached_decision_path", outputPath(paramCachedDecision),
202-
"--pod_spec_patch_path", outputPath(paramPodSpecPatch),
203-
"--condition_path", outputPath(paramCondition),
204-
"--kubernetes_config", inputValue(paramKubernetesConfig),
205-
"--http_proxy", proxy.GetConfig().GetHttpProxy(),
206-
"--https_proxy", proxy.GetConfig().GetHttpsProxy(),
207-
"--no_proxy", proxy.GetConfig().GetNoProxy(),
208-
}
209-
if c.cacheDisabled {
210-
args = append(args, "--cache_disabled")
211-
}
212-
if value, ok := os.LookupEnv(PipelineLogLevelEnvVar); ok {
213-
args = append(args, "--log_level", value)
214-
}
215-
if value, ok := os.LookupEnv(PublishLogsEnvVar); ok {
216-
args = append(args, "--publish_logs", value)
191+
return name, nil
192+
}
193+
194+
logLevel, _ := os.LookupEnv(PipelineLogLevelEnvVar)
195+
publishLogs, _ := os.LookupEnv(PublishLogsEnvVar)
196+
197+
driverPlugin, err := driverPlugin(map[string]interface{}{
198+
"type": "CONTAINER",
199+
"pipeline_name": c.spec.GetPipelineInfo().GetName(),
200+
"run_id": runID(),
201+
"run_name": runResourceName(),
202+
"run_display_name": c.job.DisplayName,
203+
"dag_execution_id": inputValue(paramParentDagID),
204+
"component": inputValue(paramComponent),
205+
"task": inputValue(paramTask),
206+
"task_name": inputValue(paramTaskName),
207+
"container": inputValue(paramContainer),
208+
"iteration_index": inputValue(paramIterationIndex),
209+
"cached_decision_path": outputPath(paramCachedDecision),
210+
"pod_spec_patch_path": outputPath(paramPodSpecPatch),
211+
"condition_path": outputPath(paramCondition),
212+
"kubernetes_config": inputValue(paramKubernetesConfig),
213+
"http_proxy": proxy.GetConfig().GetHttpProxy(),
214+
"https_proxy": proxy.GetConfig().GetHttpsProxy(),
215+
"no_proxy": proxy.GetConfig().GetNoProxy(),
216+
"cache_disabled": c.cacheDisabled,
217+
"log_level": logLevel,
218+
"publish_logs": publishLogs,
219+
})
220+
if err != nil {
221+
return "", fmt.Errorf("failed to create container driver template: %w", err)
217222
}
218223

219224
t := &wfapi.Template{
@@ -231,22 +236,16 @@ func (c *workflowCompiler) addContainerDriverTemplate() string {
231236
},
232237
Outputs: wfapi.Outputs{
233238
Parameters: []wfapi.Parameter{
234-
{Name: paramPodSpecPatch, ValueFrom: &wfapi.ValueFrom{Path: "/tmp/outputs/pod-spec-patch", Default: wfapi.AnyStringPtr("")}},
235-
{Name: paramCachedDecision, Default: wfapi.AnyStringPtr("false"), ValueFrom: &wfapi.ValueFrom{Path: "/tmp/outputs/cached-decision", Default: wfapi.AnyStringPtr("false")}},
236-
{Name: paramCondition, ValueFrom: &wfapi.ValueFrom{Path: "/tmp/outputs/condition", Default: wfapi.AnyStringPtr("true")}},
239+
{Name: paramPodSpecPatch, ValueFrom: &wfapi.ValueFrom{JSONPath: "$.pod-spec-patch", Default: wfapi.AnyStringPtr("")}},
240+
{Name: paramCachedDecision, Default: wfapi.AnyStringPtr("false"), ValueFrom: &wfapi.ValueFrom{JSONPath: "$.cached-decision", Default: wfapi.AnyStringPtr("false")}},
241+
{Name: paramCondition, ValueFrom: &wfapi.ValueFrom{JSONPath: "$.condition", Default: wfapi.AnyStringPtr("true")}},
237242
},
238243
},
239-
Container: &k8score.Container{
240-
Image: c.driverImage,
241-
Command: c.driverCommand,
242-
Args: args,
243-
Resources: driverResources,
244-
Env: proxy.GetConfig().GetEnvVars(),
245-
},
244+
Plugin: driverPlugin,
246245
}
247246
c.templates[name] = t
248247
c.wf.Spec.Templates = append(c.wf.Spec.Templates, *t)
249-
return name
248+
return name, nil
250249
}
251250

252251
type containerExecutorInputs struct {

backend/src/v2/compiler/argocompiler/dag.go

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import (
2424
wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
2525
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
2626
"github.com/kubeflow/pipelines/backend/src/v2/compiler"
27-
k8score "k8s.io/api/core/v1"
2827
"k8s.io/apimachinery/pkg/util/intstr"
2928
)
3029

@@ -283,7 +282,7 @@ func (c *workflowCompiler) task(name string, task *pipelinespec.PipelineTaskSpec
283282
driverTaskName := name + "-driver"
284283
// The following call will return an empty string for tasks without kubernetes-specific annotation.
285284
kubernetesConfigPlaceholder, _ := c.useKubernetesImpl(componentName)
286-
driver, driverOutputs := c.containerDriverTask(driverTaskName, containerDriverInputs{
285+
driver, driverOutputs, err := c.containerDriverTask(driverTaskName, containerDriverInputs{
287286
component: componentSpecPlaceholder,
288287
task: taskSpecJson,
289288
container: containerPlaceholder,
@@ -292,6 +291,9 @@ func (c *workflowCompiler) task(name string, task *pipelinespec.PipelineTaskSpec
292291
kubernetesConfig: kubernetesConfigPlaceholder,
293292
taskName: name,
294293
})
294+
if err != nil {
295+
return nil, err
296+
}
295297
if task.GetTriggerPolicy().GetCondition() == "" {
296298
driverOutputs.condition = ""
297299
}
@@ -531,9 +533,13 @@ func (c *workflowCompiler) dagDriverTask(name string, inputs dagDriverInputs) (*
531533
Value: wfapi.AnyStringPtr(inputs.taskName),
532534
})
533535
}
536+
dagDriverTemplate, err := c.addDAGDriverTemplate()
537+
if err != nil {
538+
return nil, nil, err
539+
}
534540
t := &wfapi.DAGTask{
535541
Name: name,
536-
Template: c.addDAGDriverTemplate(),
542+
Template: dagDriverTemplate,
537543
Arguments: wfapi.Arguments{
538544
Parameters: params,
539545
},
@@ -545,40 +551,40 @@ func (c *workflowCompiler) dagDriverTask(name string, inputs dagDriverInputs) (*
545551
}, nil
546552
}
547553

548-
func (c *workflowCompiler) addDAGDriverTemplate() string {
554+
func (c *workflowCompiler) addDAGDriverTemplate() (string, error) {
549555
name := "system-dag-driver"
550556
_, ok := c.templates[name]
551557
if ok {
552-
return name
553-
}
554-
555-
args := []string{
556-
"--type", inputValue(paramDriverType),
557-
"--pipeline_name", c.spec.GetPipelineInfo().GetName(),
558-
"--run_id", runID(),
559-
"--run_name", runResourceName(),
560-
"--run_display_name", c.job.DisplayName,
561-
"--dag_execution_id", inputValue(paramParentDagID),
562-
"--component", inputValue(paramComponent),
563-
"--task", inputValue(paramTask),
564-
"--task_name", inputValue(paramTaskName),
565-
"--runtime_config", inputValue(paramRuntimeConfig),
566-
"--iteration_index", inputValue(paramIterationIndex),
567-
"--execution_id_path", outputPath(paramExecutionID),
568-
"--iteration_count_path", outputPath(paramIterationCount),
569-
"--condition_path", outputPath(paramCondition),
570-
"--http_proxy", proxy.GetConfig().GetHttpProxy(),
571-
"--https_proxy", proxy.GetConfig().GetHttpsProxy(),
572-
"--no_proxy", proxy.GetConfig().GetNoProxy(),
573-
}
574-
if c.cacheDisabled {
575-
args = append(args, "--cache_disabled")
576-
}
577-
if value, ok := os.LookupEnv(PipelineLogLevelEnvVar); ok {
578-
args = append(args, "--log_level", value)
579-
}
580-
if value, ok := os.LookupEnv(PublishLogsEnvVar); ok {
581-
args = append(args, "--publish_logs", value)
558+
return name, nil
559+
}
560+
561+
logLevel, _ := os.LookupEnv(PipelineLogLevelEnvVar)
562+
publishLogs, _ := os.LookupEnv(PublishLogsEnvVar)
563+
564+
driverPlugin, err := driverPlugin(map[string]interface{}{
565+
"type": inputValue(paramDriverType),
566+
"pipeline_name": c.spec.GetPipelineInfo().GetName(),
567+
"run_id": runID(),
568+
"run_name": runResourceName(),
569+
"run_display_name": c.job.DisplayName,
570+
"dag_execution_id": inputValue(paramParentDagID),
571+
"component": inputValue(paramComponent),
572+
"task": inputValue(paramTask),
573+
"task_name": inputValue(paramTaskName),
574+
"runtime_config": inputValue(paramRuntimeConfig),
575+
"iteration_index": inputValue(paramIterationIndex),
576+
"execution_id_path": outputPath(paramExecutionID),
577+
"iteration_count_path": outputPath(paramIterationCount),
578+
"condition_path": outputPath(paramCondition),
579+
"http_proxy": proxy.GetConfig().GetHttpProxy(),
580+
"https_proxy": proxy.GetConfig().GetHttpsProxy(),
581+
"no_proxy": proxy.GetConfig().GetNoProxy(),
582+
"cache_disabled": c.cacheDisabled,
583+
"log_level": logLevel,
584+
"publish_logs": publishLogs,
585+
})
586+
if err != nil {
587+
return "", fmt.Errorf("failed to create dag driver template: %w", err)
582588
}
583589

584590
t := &wfapi.Template{
@@ -601,17 +607,11 @@ func (c *workflowCompiler) addDAGDriverTemplate() string {
601607
{Name: paramCondition, ValueFrom: &wfapi.ValueFrom{Path: "/tmp/outputs/condition", Default: wfapi.AnyStringPtr("true")}},
602608
},
603609
},
604-
Container: &k8score.Container{
605-
Image: c.driverImage,
606-
Command: c.driverCommand,
607-
Args: args,
608-
Resources: driverResources,
609-
Env: proxy.GetConfig().GetEnvVars(),
610-
},
610+
Plugin: driverPlugin,
611611
}
612612
c.templates[name] = t
613613
c.wf.Spec.Templates = append(c.wf.Spec.Templates, *t)
614-
return name
614+
return name, nil
615615
}
616616

617617
func addImplicitDependencies(dagSpec *pipelinespec.DagSpec) error {
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package argocompiler
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
7+
)
8+
9+
func driverPlugin(params map[string]interface{}) (*wfapi.Plugin, error) {
10+
pluginConfig := map[string]interface{}{
11+
"driver-plugin": map[string]interface{}{
12+
"args": params,
13+
},
14+
}
15+
jsonConfig, err := json.Marshal(pluginConfig)
16+
if err != nil {
17+
return nil, fmt.Errorf("driver plugin creation error: marshaling plugin config to JSON failed: %w", err)
18+
}
19+
return &wfapi.Plugin{Object: wfapi.Object{
20+
Value: jsonConfig,
21+
}}, nil
22+
}

0 commit comments

Comments
 (0)