Skip to content

Commit 1fd50b9

Browse files
committed
Handle optional pipeline inputs in the driver (kubeflow#11788)
If the pipeline run is submitted without specifying an optional parameter and there is no default, it was not handled by the driver. The approach taken is explicitly set null for these values and let the driver handle if the component parameter has a default that can be used in the launcher. Signed-off-by: mprahl <[email protected]> (cherry picked from commit bb7a108)
1 parent 5a1fe40 commit 1fd50b9

File tree

4 files changed

+76
-10
lines changed

4 files changed

+76
-10
lines changed

backend/src/v2/compiler/argocompiler/argo.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,12 @@ func Compile(jobArg *pipelinespec.PipelineJob, kubernetesSpecArg *pipelinespec.S
7272
specParams := spec.GetRoot().GetInputDefinitions().GetParameters()
7373
for name, param := range specParams {
7474
_, ok := job.RuntimeConfig.ParameterValues[name]
75-
if !ok && param.GetDefaultValue() != nil {
76-
job.RuntimeConfig.ParameterValues[name] = param.GetDefaultValue()
75+
if !ok {
76+
if param.GetDefaultValue() != nil {
77+
job.RuntimeConfig.ParameterValues[name] = param.GetDefaultValue()
78+
} else if param.IsOptional {
79+
job.RuntimeConfig.ParameterValues[name] = structpb.NewNullValue()
80+
}
7781
}
7882
}
7983

backend/src/v2/driver/driver.go

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package driver
1717
import (
1818
"context"
1919
"encoding/json"
20+
"errors"
2021
"fmt"
2122
"slices"
2223
"strconv"
@@ -51,6 +52,8 @@ var dummyImages = map[string]string{
5152
"argostub/deletepvc": "delete PVC",
5253
}
5354

55+
var ErrResolvedParameterNull = errors.New("the resolved input parameter is null")
56+
5457
// TODO(capri-xiyue): Move driver to component package
5558
// Driver options
5659
type Options struct {
@@ -191,6 +194,7 @@ func RootDAG(ctx context.Context, opts Options, mlmd *metadata.Client) (executio
191194
if err != nil {
192195
return nil, err
193196
}
197+
194198
executorInput := &pipelinespec.ExecutorInput{
195199
Inputs: &pipelinespec.ExecutorInput_Inputs{
196200
ParameterValues: opts.RuntimeConfig.GetParameterValues(),
@@ -803,7 +807,7 @@ func extendPodSpecPatch(
803807
for _, secretAsVolume := range kubernetesExecutorConfig.GetSecretAsVolume() {
804808
var secretName string
805809
if secretAsVolume.SecretNameParameter != nil {
806-
resolvedSecretName, err := resolveInputParameter(ctx, dag, pipeline, opts, mlmd,
810+
resolvedSecretName, err := resolveInputParameterStr(ctx, dag, pipeline, opts, mlmd,
807811
secretAsVolume.SecretNameParameter, inputParams)
808812
if err != nil {
809813
return fmt.Errorf("failed to resolve secret name: %w", err)
@@ -847,7 +851,7 @@ func extendPodSpecPatch(
847851

848852
var secretName string
849853
if secretAsEnv.SecretNameParameter != nil {
850-
resolvedSecretName, err := resolveInputParameter(ctx, dag, pipeline, opts, mlmd,
854+
resolvedSecretName, err := resolveInputParameterStr(ctx, dag, pipeline, opts, mlmd,
851855
secretAsEnv.SecretNameParameter, inputParams)
852856
if err != nil {
853857
return fmt.Errorf("failed to resolve secret name: %w", err)
@@ -869,7 +873,7 @@ func extendPodSpecPatch(
869873
for _, configMapAsVolume := range kubernetesExecutorConfig.GetConfigMapAsVolume() {
870874
var configMapName string
871875
if configMapAsVolume.ConfigMapNameParameter != nil {
872-
resolvedSecretName, err := resolveInputParameter(ctx, dag, pipeline, opts, mlmd,
876+
resolvedSecretName, err := resolveInputParameterStr(ctx, dag, pipeline, opts, mlmd,
873877
configMapAsVolume.ConfigMapNameParameter, inputParams)
874878
if err != nil {
875879
return fmt.Errorf("failed to resolve configmap name: %w", err)
@@ -915,7 +919,7 @@ func extendPodSpecPatch(
915919

916920
var configMapName string
917921
if configMapAsEnv.ConfigMapNameParameter != nil {
918-
resolvedSecretName, err := resolveInputParameter(ctx, dag, pipeline, opts, mlmd,
922+
resolvedSecretName, err := resolveInputParameterStr(ctx, dag, pipeline, opts, mlmd,
919923
configMapAsEnv.ConfigMapNameParameter, inputParams)
920924
if err != nil {
921925
return fmt.Errorf("failed to resolve configmap name: %w", err)
@@ -937,7 +941,7 @@ func extendPodSpecPatch(
937941
for _, imagePullSecret := range kubernetesExecutorConfig.GetImagePullSecret() {
938942
var secretName string
939943
if imagePullSecret.SecretNameParameter != nil {
940-
resolvedSecretName, err := resolveInputParameter(ctx, dag, pipeline, opts, mlmd,
944+
resolvedSecretName, err := resolveInputParameterStr(ctx, dag, pipeline, opts, mlmd,
941945
imagePullSecret.SecretNameParameter, inputParams)
942946
if err != nil {
943947
return fmt.Errorf("failed to resolve image pull secret name: %w", err)
@@ -1497,8 +1501,20 @@ func resolveInputs(
14971501
for name, paramSpec := range task.GetInputs().GetParameters() {
14981502
v, err := resolveInputParameter(ctx, dag, pipeline, opts, mlmd, paramSpec, inputParams)
14991503
if err != nil {
1504+
if !errors.Is(err, ErrResolvedParameterNull) {
1505+
return nil, err
1506+
}
1507+
1508+
componentParam, ok := opts.Component.GetInputDefinitions().GetParameters()[name]
1509+
if ok && componentParam != nil && componentParam.IsOptional {
1510+
// If the resolved paramter was null and the component input parameter is optional, just skip setting
1511+
// it and the launcher will handle defaults.
1512+
continue
1513+
}
1514+
15001515
return nil, err
15011516
}
1517+
15021518
inputs.ParameterValues[name] = v
15031519
}
15041520

@@ -1515,7 +1531,9 @@ func resolveInputs(
15151531
}
15161532

15171533
// resolveInputParameter resolves an InputParameterSpec
1518-
// using a given input context via InputParams.
1534+
// using a given input context via InputParams. ErrResolvedParameterNull is returned if paramSpec
1535+
// is a component input parameter and parameter resolves to a null value (i.e. an optional pipeline input with no
1536+
// default). The caller can decide if this is allowed in that context.
15191537
func resolveInputParameter(
15201538
ctx context.Context,
15211539
dag *metadata.DAG,
@@ -1539,6 +1557,13 @@ func resolveInputParameter(
15391557
if !ok {
15401558
return nil, paramError(fmt.Errorf("parent DAG does not have input parameter %s", componentInput))
15411559
}
1560+
1561+
if _, isNullValue := v.GetKind().(*structpb.Value_NullValue); isNullValue {
1562+
// Null values are only allowed for optional pipeline input parameters with no values. The caller has this
1563+
// context to know if this is allowed.
1564+
return nil, fmt.Errorf("%w: %s", ErrResolvedParameterNull, componentInput)
1565+
}
1566+
15421567
return v, nil
15431568

15441569
// This is the case where the input comes from the output of an upstream task.
@@ -1588,6 +1613,33 @@ func resolveInputParameter(
15881613
}
15891614
}
15901615

1616+
// resolveInputParameterStr is like resolveInputParameter but returns an error if the resolved value is not a non-empty
1617+
// string.
1618+
func resolveInputParameterStr(
1619+
ctx context.Context,
1620+
dag *metadata.DAG,
1621+
pipeline *metadata.Pipeline,
1622+
opts Options,
1623+
mlmd *metadata.Client,
1624+
paramSpec *pipelinespec.TaskInputsSpec_InputParameterSpec,
1625+
inputParams map[string]*structpb.Value,
1626+
) (*structpb.Value, error) {
1627+
val, err := resolveInputParameter(ctx, dag, pipeline, opts, mlmd, paramSpec, inputParams)
1628+
if err != nil {
1629+
return nil, err
1630+
}
1631+
1632+
if typedVal, ok := val.GetKind().(*structpb.Value_StringValue); ok && typedVal != nil {
1633+
if typedVal.StringValue == "" {
1634+
return nil, fmt.Errorf("resolving input parameter with spec %s. Expected a non-empty string.", paramSpec)
1635+
}
1636+
} else {
1637+
return nil, fmt.Errorf("resolving input parameter with spec %s. Expected a string but got: %T", paramSpec, val.GetKind())
1638+
}
1639+
1640+
return val, nil
1641+
}
1642+
15911643
// resolveInputArtifact resolves an InputArtifactSpec
15921644
// using a given input context via inputArtifacts.
15931645
func resolveInputArtifact(
@@ -2366,7 +2418,7 @@ func makeVolumeMountPatch(
23662418
}
23672419
}
23682420

2369-
resolvedPvcName, err := resolveInputParameter(ctx, dag, pipeline, opts, mlmd,
2421+
resolvedPvcName, err := resolveInputParameterStr(ctx, dag, pipeline, opts, mlmd,
23702422
pvcNameParameter, inputParams)
23712423
if err != nil {
23722424
return nil, nil, fmt.Errorf("failed to resolve pvc name: %w", err)

samples/v2/component_with_optional_inputs.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ def component_op(
2323
input_str1: Optional[str] = 'string default value',
2424
input_str2: Optional[str] = None,
2525
input_str3: Optional[str] = None,
26+
input_str4_from_pipeline: Optional[str] = "Some component default",
27+
input_str5_from_pipeline: Optional[str] = "Some component default",
28+
input_str6_from_pipeline: Optional[str] = None,
2629
input_bool1: Optional[bool] = True,
2730
input_bool2: Optional[bool] = None,
2831
input_dict: Optional[Dict[str, int]] = {"a": 1},
@@ -32,6 +35,9 @@ def component_op(
3235
print(f'input_str1: {input_str1}, type: {type(input_str1)}')
3336
print(f'input_str2: {input_str2}, type: {type(input_str2)}')
3437
print(f'input_str3: {input_str3}, type: {type(input_str3)}')
38+
print(f'input_str4_from_pipeline: {input_str4_from_pipeline}, type: {type(input_str4_from_pipeline)}')
39+
print(f'input_str5_from_pipeline: {input_str5_from_pipeline}, type: {type(input_str5_from_pipeline)}')
40+
print(f'input_str6_from_pipeline: {input_str6_from_pipeline}, type: {type(input_str6_from_pipeline)}')
3541
print(f'input_bool1: {input_bool1}, type: {type(input_bool1)}')
3642
print(f'input_bool2: {input_bool2}, type: {type(input_bool2)}')
3743
print(f'input_bool: {input_dict}, type: {type(input_dict)}')
@@ -40,10 +46,13 @@ def component_op(
4046

4147

4248
@dsl.pipeline(name='v2-component-optional-input')
43-
def pipeline():
49+
def pipeline(input_str4: Optional[str] = None, input_str5: Optional[str] = "Some pipeline default", input_str6: Optional[str] = None):
4450
component_op(
4551
input_str1='Hello',
4652
input_str2='World',
53+
input_str4_from_pipeline=input_str4,
54+
input_str5_from_pipeline=input_str5,
55+
input_str6_from_pipeline=input_str6,
4756
)
4857

4958

samples/v2/component_with_optional_inputs_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun,
3838
'parameters': {
3939
'input_str1': 'Hello',
4040
'input_str2': 'World',
41+
'input_str5_from_pipeline': 'Some pipeline default',
4142
},
4243
},
4344
'outputs': {},

0 commit comments

Comments
 (0)