Skip to content

Commit 6613a58

Browse files
committed
add backend support for toleration lists.
clarify toleration json docs (cherry picked from commit 90909fc) Signed-off-by: Humair Khan <[email protected]>
1 parent dd29e6c commit 6613a58

File tree

4 files changed

+238
-11
lines changed

4 files changed

+238
-11
lines changed

backend/src/v2/driver/driver.go

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -785,19 +785,51 @@ func extendPodSpecPatch(
785785
if toleration != nil {
786786
k8sToleration := &k8score.Toleration{}
787787
if toleration.TolerationJson != nil {
788-
err := resolveK8sJsonParameter(ctx, opts, dag, pipeline, mlmd,
789-
toleration.GetTolerationJson(), inputParams, k8sToleration)
788+
resolvedParam, err := resolveInputParameter(ctx, dag, pipeline, opts, mlmd,
789+
toleration.GetTolerationJson(), inputParams)
790790
if err != nil {
791791
return fmt.Errorf("failed to resolve toleration: %w", err)
792792
}
793+
794+
// TolerationJson can be either a single toleration or list of tolerations
795+
// the field accepts both, and in both cases the tolerations are appended
796+
// to the total executor pod toleration list.
797+
var paramJSON []byte
798+
isSingleToleration := resolvedParam.GetStructValue() != nil
799+
isListToleration := resolvedParam.GetListValue() != nil
800+
if isSingleToleration {
801+
paramJSON, err = resolvedParam.GetStructValue().MarshalJSON()
802+
if err != nil {
803+
return err
804+
}
805+
var singleToleration k8score.Toleration
806+
if err = json.Unmarshal(paramJSON, &singleToleration); err != nil {
807+
return fmt.Errorf("failed to marshal single toleration to json: %w", err)
808+
}
809+
k8sTolerations = append(k8sTolerations, singleToleration)
810+
} else if isListToleration {
811+
paramJSON, err = resolvedParam.GetListValue().MarshalJSON()
812+
if err != nil {
813+
return err
814+
}
815+
var k8sTolerationsList []k8score.Toleration
816+
if err = json.Unmarshal(paramJSON, &k8sTolerationsList); err != nil {
817+
return fmt.Errorf("failed to marshal list toleration to json: %w", err)
818+
}
819+
k8sTolerations = append(k8sTolerations, k8sTolerationsList...)
820+
} else {
821+
return fmt.Errorf("encountered unexpected toleration proto value, "+
822+
"must be either struct or list type: %w", err)
823+
}
793824
} else {
794825
k8sToleration.Key = toleration.Key
795826
k8sToleration.Operator = k8score.TolerationOperator(toleration.Operator)
796827
k8sToleration.Value = toleration.Value
797828
k8sToleration.Effect = k8score.TaintEffect(toleration.Effect)
798829
k8sToleration.TolerationSeconds = toleration.TolerationSeconds
830+
k8sTolerations = append(k8sTolerations, *k8sToleration)
799831
}
800-
k8sTolerations = append(k8sTolerations, *k8sToleration)
832+
801833
}
802834
}
803835
podSpec.Tolerations = k8sTolerations

backend/src/v2/driver/driver_test.go

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2075,6 +2075,165 @@ func Test_extendPodSpecPatch_Tolerations(t *testing.T) {
20752075
}),
20762076
},
20772077
},
2078+
{
2079+
"Valid - toleration json - toleration list",
2080+
&kubernetesplatform.KubernetesExecutorConfig{
2081+
Tolerations: []*kubernetesplatform.Toleration{
2082+
{
2083+
TolerationJson: inputParamComponent("param_1"),
2084+
},
2085+
},
2086+
},
2087+
&k8score.PodSpec{
2088+
Containers: []k8score.Container{
2089+
{
2090+
Name: "main",
2091+
},
2092+
},
2093+
Tolerations: []k8score.Toleration{
2094+
{
2095+
Key: "key1",
2096+
Operator: "Equal",
2097+
Value: "value1",
2098+
Effect: "NoSchedule",
2099+
TolerationSeconds: int64Ptr(3601),
2100+
},
2101+
{
2102+
Key: "key2",
2103+
Operator: "Equal",
2104+
Value: "value2",
2105+
Effect: "NoSchedule",
2106+
TolerationSeconds: int64Ptr(3602),
2107+
},
2108+
{
2109+
Key: "key3",
2110+
Operator: "Equal",
2111+
Value: "value3",
2112+
Effect: "NoSchedule",
2113+
TolerationSeconds: int64Ptr(3603),
2114+
},
2115+
},
2116+
},
2117+
map[string]*structpb.Value{
2118+
"param_1": validListOfStructsOrPanic([]map[string]interface{}{
2119+
{
2120+
"key": "key1",
2121+
"operator": "Equal",
2122+
"value": "value1",
2123+
"effect": "NoSchedule",
2124+
"tolerationSeconds": 3601,
2125+
},
2126+
{
2127+
"key": "key2",
2128+
"operator": "Equal",
2129+
"value": "value2",
2130+
"effect": "NoSchedule",
2131+
"tolerationSeconds": 3602,
2132+
},
2133+
{
2134+
"key": "key3",
2135+
"operator": "Equal",
2136+
"value": "value3",
2137+
"effect": "NoSchedule",
2138+
"tolerationSeconds": 3603,
2139+
},
2140+
}),
2141+
},
2142+
},
2143+
{
2144+
"Valid - toleration json - list toleration & single toleration & constant toleration",
2145+
&kubernetesplatform.KubernetesExecutorConfig{
2146+
Tolerations: []*kubernetesplatform.Toleration{
2147+
{
2148+
TolerationJson: inputParamComponent("param_1"),
2149+
},
2150+
{
2151+
TolerationJson: inputParamComponent("param_2"),
2152+
},
2153+
{
2154+
Key: "key5",
2155+
Operator: "Equal",
2156+
Value: "value5",
2157+
Effect: "NoSchedule",
2158+
},
2159+
},
2160+
},
2161+
&k8score.PodSpec{
2162+
Containers: []k8score.Container{
2163+
{
2164+
Name: "main",
2165+
},
2166+
},
2167+
Tolerations: []k8score.Toleration{
2168+
{
2169+
Key: "key1",
2170+
Operator: "Equal",
2171+
Value: "value1",
2172+
Effect: "NoSchedule",
2173+
TolerationSeconds: int64Ptr(3601),
2174+
},
2175+
{
2176+
Key: "key2",
2177+
Operator: "Equal",
2178+
Value: "value2",
2179+
Effect: "NoSchedule",
2180+
TolerationSeconds: int64Ptr(3602),
2181+
},
2182+
{
2183+
Key: "key3",
2184+
Operator: "Equal",
2185+
Value: "value3",
2186+
Effect: "NoSchedule",
2187+
TolerationSeconds: int64Ptr(3603),
2188+
},
2189+
{
2190+
Key: "key4",
2191+
Operator: "Equal",
2192+
Value: "value4",
2193+
Effect: "NoSchedule",
2194+
TolerationSeconds: int64Ptr(3604),
2195+
},
2196+
{
2197+
Key: "key5",
2198+
Operator: "Equal",
2199+
Value: "value5",
2200+
Effect: "NoSchedule",
2201+
},
2202+
},
2203+
},
2204+
map[string]*structpb.Value{
2205+
"param_1": validListOfStructsOrPanic([]map[string]interface{}{
2206+
{
2207+
"key": "key1",
2208+
"operator": "Equal",
2209+
"value": "value1",
2210+
"effect": "NoSchedule",
2211+
"tolerationSeconds": 3601,
2212+
},
2213+
{
2214+
"key": "key2",
2215+
"operator": "Equal",
2216+
"value": "value2",
2217+
"effect": "NoSchedule",
2218+
"tolerationSeconds": 3602,
2219+
},
2220+
{
2221+
"key": "key3",
2222+
"operator": "Equal",
2223+
"value": "value3",
2224+
"effect": "NoSchedule",
2225+
"tolerationSeconds": 3603,
2226+
},
2227+
}),
2228+
"param_2": validValueStructOrPanic(map[string]interface{}{
2229+
"key": "key4",
2230+
"operator": "Equal",
2231+
"value": "value4",
2232+
"effect": "NoSchedule",
2233+
"tolerationSeconds": 3604,
2234+
}),
2235+
},
2236+
},
20782237
}
20792238
for _, tt := range tests {
20802239
t.Run(tt.name, func(t *testing.T) {
@@ -2554,6 +2713,18 @@ func Test_extendPodSpecPatch_GenericEphemeralVolume(t *testing.T) {
25542713
}
25552714
}
25562715

2716+
func validListOfStructsOrPanic(data []map[string]interface{}) *structpb.Value {
2717+
var listValues []*structpb.Value
2718+
for _, item := range data {
2719+
s, err := structpb.NewStruct(item)
2720+
if err != nil {
2721+
panic(err)
2722+
}
2723+
listValues = append(listValues, structpb.NewStructValue(s))
2724+
}
2725+
return structpb.NewListValue(&structpb.ListValue{Values: listValues})
2726+
}
2727+
25572728
func validValueStructOrPanic(data map[string]interface{}) *structpb.Value {
25582729
s, err := structpb.NewStruct(data)
25592730
if err != nil {

kubernetes_platform/python/kfp/kubernetes/toleration.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
1514
from typing import Optional, Union
1615

1716
from google.protobuf import json_format
@@ -82,21 +81,46 @@ def add_toleration(
8281

8382

8483
def add_toleration_json(task: PipelineTask,
85-
toleration_json: Union[pipeline_channel.PipelineParameterChannel, dict]
84+
toleration_json: Union[pipeline_channel.PipelineParameterChannel, list, dict]
8685
):
87-
"""Add a Pod Toleration in the form of a JSON to a task.
86+
"""Add a Pod Toleration in the form of a Pipeline Input JSON to a task.
8887
8988
Args:
9089
task:
9190
Pipeline task.
9291
toleration_json:
93-
a toleration provided as dict or input parameter. Takes
94-
precedence over other key, operator, value, effect,
95-
and toleration_seconds.
92+
a toleration that is a pipeline input parameter.
93+
The input parameter must be of type dict or list.
94+
95+
If it is a dict, it must be a single toleration object.
96+
For example a pipeline input parameter in this case could be:
97+
{
98+
"key": "key1",
99+
"operator": "Equal",
100+
"value": "value1",
101+
"effect": "NoSchedule"
102+
}
96103
104+
If it is a list, it must be list of toleration objects.
105+
For example a pipeline input parameter in this case could be:
106+
[
107+
{
108+
"key": "key1",
109+
"operator": "Equal",
110+
"value": "value1",
111+
"effect": "NoSchedule"
112+
},
113+
{
114+
"key": "key2",
115+
"operator": "Exists",
116+
"effect": "NoExecute"
117+
}
118+
]
97119
Returns:
98120
Task object with added toleration.
99121
"""
122+
if not isinstance(toleration_json, pipeline_channel.PipelineParameterChannel):
123+
raise TypeError("toleration_json must be a Pipeline Input Parameter.")
100124

101125
msg = common.get_existing_kubernetes_config_as_message(task)
102126
toleration = pb.Toleration()

kubernetes_platform/python/test/unit/test_tolerations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def test_component_pipeline_input_one(self):
173173
# checks that a pipeline input for
174174
# tasks is supported
175175
@dsl.pipeline
176-
def my_pipeline(toleration_input: str):
176+
def my_pipeline(toleration_input: dict):
177177
task = comp()
178178
kubernetes.add_toleration_json(
179179
task,
@@ -204,7 +204,7 @@ def test_component_pipeline_input_two(self):
204204
# checks that multiple pipeline inputs for
205205
# different tasks are supported
206206
@dsl.pipeline
207-
def my_pipeline(toleration_input_1: str, toleration_input_2: str):
207+
def my_pipeline(toleration_input_1: dict, toleration_input_2: list):
208208
t1 = comp()
209209
kubernetes.add_toleration_json(
210210
t1,

0 commit comments

Comments
 (0)