Skip to content

Commit d0a0706

Browse files
authored
feat(sdk): add custom task param support in command spec (kubeflow#1061)
1 parent d005b7b commit d0a0706

File tree

4 files changed

+16
-3
lines changed

4 files changed

+16
-3
lines changed

sdk/python/kfp_tekton/compiler/compiler.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1006,13 +1006,23 @@ def map_cel_vars(a):
10061006
if i.get('image', '') in TEKTON_CUSTOM_TASK_IMAGES:
10071007
custom_task_args = {}
10081008
container_args = i.get('args', [])
1009+
custom_task_command = {}
1010+
container_command = i.get('command', [])
10091011
for index, item in enumerate(container_args):
10101012
if item.startswith('--'):
10111013
custom_task_args[item[2:]] = container_args[index + 1]
1014+
for index, item in enumerate(container_command):
1015+
if item.startswith('--'):
1016+
custom_task_command[item[2:]] = container_command[index + 1]
10121017
non_param_keys = ['name', 'apiVersion', 'kind', 'taskSpec', 'taskRef']
10131018
task_params = []
1019+
command_params = []
1020+
for key, value in custom_task_command.items():
1021+
task_params.append({'name': key, 'value': value})
1022+
# Parameters in command spec get higher priority
1023+
command_params.append(key)
10141024
for key, value in custom_task_args.items():
1015-
if key not in non_param_keys:
1025+
if key not in non_param_keys and key not in command_params:
10161026
task_params.append({'name': key, 'value': value})
10171027
task_orig_params = task_ref['params']
10181028
task_ref = {

sdk/python/tests/compiler/testdata/custom_task_params.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
custom_task_name = "some-custom-task"
2020
custom_task_api_version = "custom.tekton.dev/v1alpha1"
2121
custom_task_image = "some-image"
22-
custom_task_command = "cmd"
2322
custom_task_kind = "custom-task"
2423

2524

@@ -61,7 +60,7 @@ def custom_task(resource_label: str, foo: str, bar: Any, pi: float) -> dsl.Conta
6160
task = dsl.ContainerOp(
6261
name=custom_task_name,
6362
image=custom_task_image,
64-
command=[custom_task_command],
63+
command=["--name", foo],
6564
arguments=[
6665
"--apiVersion", custom_task_api_version,
6766
"--kind", custom_task_kind,

sdk/python/tests/compiler/testdata/custom_task_params_ref.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ spec:
5252
tasks:
5353
- name: some-custom-task
5454
params:
55+
- name: name
56+
value: $(params.foo)
5557
- name: foo
5658
value: $(params.foo)
5759
- name: bar

sdk/python/tests/compiler/testdata/custom_task_params_spec.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ spec:
5252
tasks:
5353
- name: some-custom-task
5454
params:
55+
- name: name
56+
value: $(params.foo)
5557
- name: foo
5658
value: $(params.foo)
5759
- name: bar

0 commit comments

Comments
 (0)