@@ -89,10 +89,10 @@ def add_toleration_json(task: PipelineTask,
89
89
task:
90
90
Pipeline task.
91
91
toleration_json:
92
- a toleration that is a pipeline input parameter.
92
+ a toleration that is a pipeline input parameter, dict, or list .
93
93
The input parameter must be of type dict or list.
94
94
95
- If it is a dict, it must be a single toleration object.
95
+ If it is an input parameter of type dict, it must be a single toleration object.
96
96
For example a pipeline input parameter in this case could be:
97
97
{
98
98
"key": "key1",
@@ -101,7 +101,7 @@ def add_toleration_json(task: PipelineTask,
101
101
"effect": "NoSchedule"
102
102
}
103
103
104
- If it is a list, it must be list of toleration objects.
104
+ If it is an input parameter of type list, it must be list of toleration objects.
105
105
For example a pipeline input parameter in this case could be:
106
106
[
107
107
{
@@ -116,18 +116,35 @@ def add_toleration_json(task: PipelineTask,
116
116
"effect": "NoExecute"
117
117
}
118
118
]
119
+ In the case of static list or dicts, the call wraps add_toleration.
119
120
Returns:
120
121
Task object with added toleration.
121
122
"""
122
- if not isinstance (toleration_json , pipeline_channel .PipelineParameterChannel ):
123
- raise TypeError ("toleration_json must be a Pipeline Input Parameter." )
124
123
125
- msg = common .get_existing_kubernetes_config_as_message (task )
126
- toleration = pb .Toleration ()
127
- toleration .toleration_json .CopyFrom (
128
- common .parse_k8s_parameter_input (toleration_json , task )
129
- )
130
- msg .tolerations .append (toleration )
131
- task .platform_config ["kubernetes" ] = json_format .MessageToDict (msg )
124
+ if isinstance (toleration_json , pipeline_channel .PipelineParameterChannel ):
125
+ msg = common .get_existing_kubernetes_config_as_message (task )
126
+ toleration = pb .Toleration ()
127
+ toleration .toleration_json .CopyFrom (
128
+ common .parse_k8s_parameter_input (toleration_json , task )
129
+ )
130
+ msg .tolerations .append (toleration )
131
+ task .platform_config ["kubernetes" ] = json_format .MessageToDict (msg )
132
+ elif isinstance (toleration_json , list ):
133
+ for toleration in toleration_json :
134
+ _add_dict_toleration (task , toleration )
135
+ elif isinstance (toleration_json , dict ):
136
+ _add_dict_toleration (task , toleration_json )
137
+ else :
138
+ raise ValueError ("toleration_json must be a dict, list, or input parameter" )
132
139
133
140
return task
141
+
142
+ def _add_dict_toleration (task : PipelineTask , toleration : dict ):
143
+ add_toleration (
144
+ task ,
145
+ key = toleration .get ("key" ),
146
+ value = toleration .get ("value" ),
147
+ operator = toleration .get ("operator" ),
148
+ effect = toleration .get ("effect" ),
149
+ toleration_seconds = toleration .get ("toleration_seconds" ),
150
+ )
0 commit comments