|
2 | 2 | from collections import OrderedDict |
3 | 3 |
|
4 | 4 | import flytekit.configuration |
5 | | -from flytekit import ContainerTask, Resources |
| 5 | +from flytekit import ContainerTask, Resources, PodTemplate |
6 | 6 | from flytekit.configuration import FastSerializationSettings, Image, ImageConfig |
7 | 7 | from flytekit.core.base_task import kwtypes |
8 | 8 | from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan |
|
12 | 12 | from flytekit.deck import Deck |
13 | 13 | from flytekit.models.core import identifier as identifier_models |
14 | 14 | from flytekit.models.task import Resources as resource_model |
15 | | -from flytekit.tools.translator import get_serializable, Options |
| 15 | +from flytekit.tools.translator import get_serializable |
| 16 | +from kubernetes import client |
| 17 | +from kubernetes.client import V1PodSpec, V1Container |
16 | 18 | import pytest |
17 | 19 |
|
18 | 20 | default_img = Image(name="default", fqn="test", tag="tag") |
@@ -183,3 +185,52 @@ def morning_greeter_caller(day_of_week: str) -> str: |
183 | 185 | assert len(task_spec.template.interface.outputs) == 1 |
184 | 186 | assert len(task_spec.template.nodes) == 1 |
185 | 187 | assert len(task_spec.template.nodes[0].inputs) == 2 |
| 188 | + |
| 189 | +@pytest.mark.parametrize( |
| 190 | + "fast_registration_enabled", |
| 191 | + [ |
| 192 | + pytest.param( |
| 193 | + True, id="fast registration enabled" |
| 194 | + ), |
| 195 | + pytest.param( |
| 196 | + False, id="fast registration disabled" |
| 197 | + ), |
| 198 | + ], |
| 199 | +) |
| 200 | +def test_task_with_pod_template_override(fast_registration_enabled: bool): |
| 201 | + |
| 202 | + custom_pod_template = PodTemplate(pod_spec=V1PodSpec( |
| 203 | + containers=[ |
| 204 | + V1Container( |
| 205 | + name="primary", |
| 206 | + env=[ |
| 207 | + client.V1EnvVar(name="MY_KEY", value="MY_VALUE"), |
| 208 | + ] |
| 209 | + ) |
| 210 | + ] |
| 211 | + )) |
| 212 | + |
| 213 | + @task |
| 214 | + def t(a: str) -> str: |
| 215 | + return a |
| 216 | + |
| 217 | + @workflow |
| 218 | + def wf(): |
| 219 | + t("Hello World").with_overrides(pod_template=custom_pod_template) |
| 220 | + |
| 221 | + settings = ( |
| 222 | + serialization_settings.new_builder() |
| 223 | + .with_fast_serialization_settings(FastSerializationSettings(enabled=fast_registration_enabled)) |
| 224 | + .build() |
| 225 | + ) |
| 226 | + |
| 227 | + task_spec = get_serializable(OrderedDict(), settings, wf) |
| 228 | + assert len(task_spec.template.nodes) == 1 |
| 229 | + node = task_spec.template.nodes[0] |
| 230 | + assert node.metadata.name == "t" |
| 231 | + assert node.task_node.overrides.pod_template is not None |
| 232 | + pod_template_override = node.task_node.overrides.pod_template |
| 233 | + assert pod_template_override.pod_spec # validate not empty |
| 234 | + assert len(pod_template_override.pod_spec['containers']) == 1 |
| 235 | + container = pod_template_override.pod_spec['containers'][0] |
| 236 | + assert container['env'] == [{'name': 'MY_KEY', 'value': 'MY_VALUE'}] |
0 commit comments