Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 58 additions & 35 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from flytekit.configuration.file import ConfigFile
from flytekit.constants import CopyFileDetection
from flytekit.core import constants, utils
from flytekit.core.array_node import ArrayNode
from flytekit.core.array_node_map_task import ArrayNodeMapTask
from flytekit.core.artifact import Artifact
from flytekit.core.base_task import PythonTask
Expand Down Expand Up @@ -211,6 +212,7 @@
return ""


@functools.lru_cache
def _get_pickled_target_dict(
root_entity: typing.Union[WorkflowBase, PythonTask],
) -> typing.Tuple[bytes, PickledEntity]:
Expand Down Expand Up @@ -240,6 +242,10 @@
f"Eager tasks are not supported in interactive mode. {entity.name} is an eager task."
)

if isinstance(entity, ArrayNode):
# extract WorkflowBase from ArrayNode
entity = entity.target.workflow

Check warning on line 247 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L247

Added line #L247 was not covered by tests

if isinstance(entity, PythonTask):
if isinstance(entity, (PythonAutoContainerTask, ArrayNodeMapTask)):
if isinstance(entity, ArrayNodeMapTask):
Expand Down Expand Up @@ -1062,6 +1068,22 @@
domain=self.default_domain,
)

if self.interactive_mode_enabled:
md5_bytes, pickled_target_dict = _get_pickled_target_dict(entity)

Check warning on line 1072 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L1072

Added line #L1072 was not covered by tests
if version is None:
version = self._version_from_hash(

Check warning on line 1074 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L1074

Added line #L1074 was not covered by tests
md5_bytes,
serialization_settings,
entity.python_interface.default_inputs_as_kwargs,
*FlyteRemote._get_image_names(entity),
*FlyteRemote._get_pod_template_hash(entity),
)

serialization_settings.fast_serialization_settings = self._pickle_and_upload_entity(

Check warning on line 1082 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L1082

Added line #L1082 was not covered by tests
entity,
pickled_target_dict,
)

ident = run_sync(self._serialize_and_register, entity=entity, settings=serialization_settings, version=version)

ft = self.fetch_task(
Expand Down Expand Up @@ -1097,6 +1119,22 @@
domain=self.default_domain,
)

if self.interactive_mode_enabled:
md5_bytes, pickled_target_dict = _get_pickled_target_dict(entity)

Check warning on line 1123 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L1123

Added line #L1123 was not covered by tests
if version is None:
version = self._version_from_hash(

Check warning on line 1125 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L1125

Added line #L1125 was not covered by tests
md5_bytes,
serialization_settings,
entity.python_interface.default_inputs_as_kwargs,
*FlyteRemote._get_image_names(entity),
*FlyteRemote._get_pod_template_hash(entity),
)

serialization_settings.fast_serialization_settings = self._pickle_and_upload_entity(

Check warning on line 1133 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L1133

Added line #L1133 was not covered by tests
entity,
pickled_target_dict,
)

version, _ = self._resolve_version(version, entity, serialization_settings)

ident = run_sync(
Expand Down Expand Up @@ -1469,32 +1507,25 @@

version, _ = self._resolve_version(version, entity, serialization_settings)

if self._wf_exists(
if not self._wf_exists(
name=entity.workflow.name,
version=version,
project=serialization_settings.project,
domain=serialization_settings.domain,
):
# Underlying workflow, exists, only register the launch plan itself
launch_plan_model = get_serializable(
OrderedDict(), settings=serialization_settings, entity=entity, options=options
)
ident = self.raw_register(
launch_plan_model, serialization_settings, version, create_default_launchplan=False
)
if ident is None:
raise ValueError("Failed to register launch plan, identifier returned was empty...")
else:
# Register the launch and everything under it
ident = run_sync(
self._serialize_and_register,
entity,
serialization_settings,
version,
options,
False,
# If workflow doesn't exist, register it first
self.register_workflow(

Check warning on line 1517 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L1517

Added line #L1517 was not covered by tests
entity.workflow, serialization_settings, version, default_launch_plan=False, options=options
)

# Underlying workflow, exists, only register the launch plan itself
launch_plan_model = get_serializable(
OrderedDict(), settings=serialization_settings, entity=entity, options=options
)
ident = self.raw_register(launch_plan_model, serialization_settings, version, create_default_launchplan=False)
if ident is None:
raise ValueError("Failed to register launch plan, identifier returned was empty...")

Check warning on line 1527 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L1527

Added line #L1527 was not covered by tests

flp = self.fetch_launch_plan(ident.project, ident.domain, ident.name, ident.version)
flp.python_interface = entity.python_interface
return flp
Expand Down Expand Up @@ -2196,17 +2227,14 @@
domain=domain or self._default_domain,
version=version,
)
version, pickled_target_dict = self._resolve_version(version, entity, ss)
version, _ = self._resolve_version(version, entity, ss)

Check warning on line 2230 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L2230

Added line #L2230 was not covered by tests

resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
resolved_identifiers_dict = asdict(resolved_identifiers)
try:
flyte_task: FlyteTask = self.fetch_task(**resolved_identifiers_dict)
flyte_task.python_interface = entity.python_interface
except FlyteEntityNotExistException:
if self.interactive_mode_enabled:
ss.fast_serialization_settings = self._pickle_and_upload_entity(entity, pickled_target_dict)

flyte_task: FlyteTask = self.register_task(entity, ss, version)

return self.execute(
Expand Down Expand Up @@ -2280,16 +2308,8 @@
domain=domain or self._default_domain,
version=version,
)
pickled_target_dict = None
if version is None and self.interactive_mode_enabled:
md5_bytes, pickled_target_dict = _get_pickled_target_dict(entity)
version = self._version_from_hash(
md5_bytes,
ss,
entity.python_interface.default_inputs_as_kwargs,
*FlyteRemote._get_image_names(entity),
*FlyteRemote._get_pod_template_hash(entity),
)
version, _ = self._resolve_version(version, entity, ss)

Check warning on line 2312 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L2312

Added line #L2312 was not covered by tests

resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
resolved_identifiers_dict = asdict(resolved_identifiers)
Expand All @@ -2300,9 +2320,12 @@
self.fetch_workflow(**resolved_identifiers_dict)
except FlyteEntityNotExistException:
logger.info("Registering workflow because it wasn't found in Flyte Admin.")
if self.interactive_mode_enabled:
ss.fast_serialization_settings = self._pickle_and_upload_entity(entity, pickled_target_dict)
self.register_workflow(entity, ss, version=version, options=options)
self.register_workflow(

Check warning on line 2323 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L2323

Added line #L2323 was not covered by tests
entity,
ss,
version=version,
options=options,
)

try:
flyte_lp = self.fetch_launch_plan(**resolved_identifiers_dict)
Expand Down
Loading