Skip to content

Commit 40db6cb

Browse files
committed
add list or dict support for add toleration json
(cherry picked from commit fb18235) Signed-off-by: Humair Khan <[email protected]>
1 parent 6613a58 commit 40db6cb

File tree

2 files changed

+102
-12
lines changed

2 files changed

+102
-12
lines changed

kubernetes_platform/python/kfp/kubernetes/toleration.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,10 @@ def add_toleration_json(task: PipelineTask,
8989
task:
9090
Pipeline task.
9191
toleration_json:
92-
a toleration that is a pipeline input parameter.
92+
a toleration that is a pipeline input parameter, dict, or list.
9393
The input parameter must be of type dict or list.
9494
95-
If it is a dict, it must be a single toleration object.
95+
If it is an input parameter of type dict, it must be a single toleration object.
9696
For example a pipeline input parameter in this case could be:
9797
{
9898
"key": "key1",
@@ -101,7 +101,7 @@ def add_toleration_json(task: PipelineTask,
101101
"effect": "NoSchedule"
102102
}
103103
104-
If it is a list, it must be list of toleration objects.
104+
If it is an input parameter of type list, it must be list of toleration objects.
105105
For example a pipeline input parameter in this case could be:
106106
[
107107
{
@@ -116,18 +116,35 @@ def add_toleration_json(task: PipelineTask,
116116
"effect": "NoExecute"
117117
}
118118
]
119+
In the case of static list or dicts, the call wraps add_toleration.
119120
Returns:
120121
Task object with added toleration.
121122
"""
122-
if not isinstance(toleration_json, pipeline_channel.PipelineParameterChannel):
123-
raise TypeError("toleration_json must be a Pipeline Input Parameter.")
124123

125-
msg = common.get_existing_kubernetes_config_as_message(task)
126-
toleration = pb.Toleration()
127-
toleration.toleration_json.CopyFrom(
128-
common.parse_k8s_parameter_input(toleration_json, task)
129-
)
130-
msg.tolerations.append(toleration)
131-
task.platform_config["kubernetes"] = json_format.MessageToDict(msg)
124+
if isinstance(toleration_json, pipeline_channel.PipelineParameterChannel):
125+
msg = common.get_existing_kubernetes_config_as_message(task)
126+
toleration = pb.Toleration()
127+
toleration.toleration_json.CopyFrom(
128+
common.parse_k8s_parameter_input(toleration_json, task)
129+
)
130+
msg.tolerations.append(toleration)
131+
task.platform_config["kubernetes"] = json_format.MessageToDict(msg)
132+
elif isinstance(toleration_json, list):
133+
for toleration in toleration_json:
134+
_add_dict_toleration(task, toleration)
135+
elif isinstance(toleration_json, dict):
136+
_add_dict_toleration(task, toleration_json)
137+
else:
138+
raise ValueError("toleration_json must be a dict, list, or input parameter")
132139

133140
return task
141+
142+
def _add_dict_toleration(task: PipelineTask, toleration: dict):
143+
add_toleration(
144+
task,
145+
key=toleration.get("key"),
146+
value=toleration.get("value"),
147+
operator=toleration.get("operator"),
148+
effect=toleration.get("effect"),
149+
toleration_seconds=toleration.get("toleration_seconds"),
150+
)

kubernetes_platform/python/test/unit/test_tolerations.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,79 @@ def my_pipeline(toleration_input_1: dict, toleration_input_2: list):
254254
}
255255
}
256256

257+
def test_component_pipeline_input_three(self):
258+
# check list or dict types for add
259+
# toleration json
260+
@dsl.pipeline
261+
def my_pipeline(toleration_input: dict):
262+
t1 = comp()
263+
kubernetes.add_toleration_json(
264+
t1,
265+
toleration_json=toleration_input,
266+
)
267+
kubernetes.add_toleration_json(
268+
t1,
269+
toleration_json={
270+
"key": "key3",
271+
"operator": "Equal",
272+
"value": "value3",
273+
"effect": "NoSchedule"
274+
},
275+
)
276+
kubernetes.add_toleration_json(
277+
t1,
278+
toleration_json=[
279+
{
280+
"key": "key1",
281+
"operator": "Equal",
282+
"value": "value1",
283+
"effect": "NoSchedule"
284+
},
285+
{
286+
"key": "key2",
287+
"operator": "Exists",
288+
"effect": "NoExecute"
289+
}
290+
],
291+
)
292+
293+
assert json_format.MessageToDict(my_pipeline.platform_spec) == {
294+
'platforms': {
295+
'kubernetes': {
296+
'deploymentSpec': {
297+
'executors': {
298+
'exec-comp': {
299+
'tolerations': [
300+
{
301+
'tolerationJson': {
302+
'componentInputParameter': 'toleration_input'
303+
}
304+
},
305+
{
306+
'key': 'key3',
307+
'operator': 'Equal',
308+
'value': 'value3',
309+
'effect': 'NoSchedule',
310+
},
311+
{
312+
'key': 'key1',
313+
'operator': 'Equal',
314+
'value': 'value1',
315+
'effect': 'NoSchedule',
316+
},
317+
{
318+
'key': 'key2',
319+
'operator': 'Exists',
320+
'effect': 'NoExecute',
321+
},
322+
]
323+
},
324+
}
325+
}
326+
}
327+
}
328+
}
329+
257330
def test_component_upstream_input_one(self):
258331
# checks that upstream task input parameters
259332
# are supported

0 commit comments

Comments
 (0)