Skip to content

Commit 6a0127c

Browse files
authored
Fix issue where pod template override pod spec was missing (#3270)
* Fix issue where pod template override pod spec was conditionally included Signed-off-by: Jason Parraga <sovietaced@gmail.com> * Add unit test Signed-off-by: Jason Parraga <sovietaced@gmail.com> * Make unit test parameterized Signed-off-by: Jason Parraga <sovietaced@gmail.com> --------- Signed-off-by: Jason Parraga <sovietaced@gmail.com>
1 parent b8bd210 commit 6a0127c

File tree

2 files changed

+56
-4
lines changed

2 files changed

+56
-4
lines changed

flytekit/tools/translator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,8 +461,9 @@ def get_serializable_node(
461461
elif isinstance(entity.flyte_entity, PythonTask):
462462
# handle pod template overrides
463463
override_pod_spec = {}
464-
if entity._pod_template is not None and settings.should_fast_serialize():
465-
entity.flyte_entity.set_command_fn(_fast_serialize_command_fn(settings, entity.flyte_entity))
464+
if entity._pod_template is not None:
465+
if settings.should_fast_serialize():
466+
entity.flyte_entity.set_command_fn(_fast_serialize_command_fn(settings, entity.flyte_entity))
466467
override_pod_spec = _serialize_pod_spec(
467468
entity._pod_template, entity.flyte_entity._get_container(settings), settings
468469
)

tests/flytekit/unit/test_translator.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections import OrderedDict
33

44
import flytekit.configuration
5-
from flytekit import ContainerTask, Resources
5+
from flytekit import ContainerTask, Resources, PodTemplate
66
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig
77
from flytekit.core.base_task import kwtypes
88
from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan
@@ -12,7 +12,9 @@
1212
from flytekit.deck import Deck
1313
from flytekit.models.core import identifier as identifier_models
1414
from flytekit.models.task import Resources as resource_model
15-
from flytekit.tools.translator import get_serializable, Options
15+
from flytekit.tools.translator import get_serializable
16+
from kubernetes import client
17+
from kubernetes.client import V1PodSpec, V1Container
1618
import pytest
1719

1820
default_img = Image(name="default", fqn="test", tag="tag")
@@ -183,3 +185,52 @@ def morning_greeter_caller(day_of_week: str) -> str:
183185
assert len(task_spec.template.interface.outputs) == 1
184186
assert len(task_spec.template.nodes) == 1
185187
assert len(task_spec.template.nodes[0].inputs) == 2
188+
189+
@pytest.mark.parametrize(
190+
"fast_registration_enabled",
191+
[
192+
pytest.param(
193+
True, id="fast registration enabled"
194+
),
195+
pytest.param(
196+
False, id="fast registration disabled"
197+
),
198+
],
199+
)
200+
def test_task_with_pod_template_override(fast_registration_enabled: bool):
201+
202+
custom_pod_template = PodTemplate(pod_spec=V1PodSpec(
203+
containers=[
204+
V1Container(
205+
name="primary",
206+
env=[
207+
client.V1EnvVar(name="MY_KEY", value="MY_VALUE"),
208+
]
209+
)
210+
]
211+
))
212+
213+
@task
214+
def t(a: str) -> str:
215+
return a
216+
217+
@workflow
218+
def wf():
219+
t("Hello World").with_overrides(pod_template=custom_pod_template)
220+
221+
settings = (
222+
serialization_settings.new_builder()
223+
.with_fast_serialization_settings(FastSerializationSettings(enabled=fast_registration_enabled))
224+
.build()
225+
)
226+
227+
task_spec = get_serializable(OrderedDict(), settings, wf)
228+
assert len(task_spec.template.nodes) == 1
229+
node = task_spec.template.nodes[0]
230+
assert node.metadata.name == "t"
231+
assert node.task_node.overrides.pod_template is not None
232+
pod_template_override = node.task_node.overrides.pod_template
233+
assert pod_template_override.pod_spec # validate not empty
234+
assert len(pod_template_override.pod_spec['containers']) == 1
235+
container = pod_template_override.pod_spec['containers'][0]
236+
assert container['env'] == [{'name': 'MY_KEY', 'value': 'MY_VALUE'}]

0 commit comments

Comments
 (0)