Skip to content

Commit 4827e55

Browse files
committed
feat: make container task return python type
-e Signed-off-by: machichima <nary12321@gmail.com>
1 parent 0f6dd70 commit 4827e55

File tree

2 files changed

+7
-14
lines changed

2 files changed

+7
-14
lines changed

flytekit/core/array_node_map_task.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
1818
from flytekit.core.interface import transform_interface_to_list_interface
1919
from flytekit.core.launch_plan import LaunchPlan
20-
from flytekit.core.promise import Promise, create_native_named_tuple, create_task_output
2120
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
2221
from flytekit.core.task import ReferenceTask
2322
from flytekit.core.type_engine import TypeEngine
@@ -380,13 +379,6 @@ def _raw_execute(self, **kwargs) -> Any:
380379
single_instance_inputs[k] = kwargs[k]
381380
try:
382381
o = self._run_task.execute(**single_instance_inputs)
383-
# For running container task in local execution, it will return
384-
# the LiteralMap. We need to convert it to native type here.
385-
if isinstance(o, _literal_models.LiteralMap):
386-
vals = [Promise(var, o.literals[var]) for var in o.literals.keys()]
387-
result = create_task_output(vals, self.python_interface)
388-
ctx = FlyteContextManager.current_context()
389-
o = create_native_named_tuple(ctx, result, self._run_task.python_interface)
390382
if outputs_expected:
391383
outputs.append(o)
392384
except Exception as exc:

flytekit/core/container_task.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from flytekit.image_spec.image_spec import ImageSpec
1515
from flytekit.loggers import logger
1616
from flytekit.models import task as _task_model
17-
from flytekit.models.literals import LiteralMap
1817
from flytekit.models.security import Secret, SecurityContext
1918

2019
_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
@@ -254,14 +253,12 @@ def _get_output_dict(self, output_directory: str) -> Dict[str, Any]:
254253
output_dict[k] = self._convert_output_val_to_correct_type(output_val, output_type)
255254
return output_dict
256255

257-
def execute(self, **kwargs) -> LiteralMap:
256+
def execute(self, **kwargs) -> Any:
258257
try:
259258
import docker
260259
except ImportError:
261260
raise ImportError(DOCKER_IMPORT_ERROR_MESSAGE)
262261

263-
from flytekit.core.type_engine import TypeEngine
264-
265262
ctx = FlyteContext.current_context()
266263

267264
# Normalize the input and output directories
@@ -289,8 +286,12 @@ def execute(self, **kwargs) -> LiteralMap:
289286
container.wait()
290287

291288
output_dict = self._get_output_dict(output_directory)
292-
outputs_literal_map = TypeEngine.dict_to_literal_map(ctx, output_dict)
293-
return outputs_literal_map
289+
if len(output_dict) == 0:
290+
return None
291+
elif len(output_dict) == 1:
292+
return list(output_dict.values())[0]
293+
elif len(output_dict) > 1:
294+
return tuple(output_dict.values())
294295

295296
def get_container(self, settings: SerializationSettings) -> _task_model.Container:
296297
# if pod_template is specified, return None here but in get_k8s_pod, return pod_template merged with container

0 commit comments

Comments
 (0)