Skip to content

Commit b8bd210

Browse files
enable notebook task/workflow registration (#3275)
* [wip] enable notebook registration
  Signed-off-by: Niels Bantilan <[email protected]>
* lint
  Signed-off-by: Niels Bantilan <[email protected]>
* updates (#3276)
  * updates
  * makes sense now
  Signed-off-by: Yee Hing Tong <[email protected]>
  ---------
  Signed-off-by: Yee Hing Tong <[email protected]>
---------
Signed-off-by: Niels Bantilan <[email protected]>
Signed-off-by: Yee Hing Tong <[email protected]>
Co-authored-by: Yee Hing Tong <[email protected]>
1 parent 62672bd commit b8bd210

File tree

1 file changed

+58
-35
lines changed

1 file changed

+58
-35
lines changed

flytekit/remote/remote.py

Lines changed: 58 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@
4040
from flytekit.configuration.file import ConfigFile
4141
from flytekit.constants import CopyFileDetection
4242
from flytekit.core import constants, utils
43+
from flytekit.core.array_node import ArrayNode
4344
from flytekit.core.array_node_map_task import ArrayNodeMapTask
4445
from flytekit.core.artifact import Artifact
4546
from flytekit.core.base_task import PythonTask
@@ -211,6 +212,7 @@ def _get_git_repo_url(source_path: str):
211212
return ""
212213

213214

215+
@functools.lru_cache
214216
def _get_pickled_target_dict(
215217
root_entity: typing.Union[WorkflowBase, PythonTask],
216218
) -> typing.Tuple[bytes, PickledEntity]:
@@ -240,6 +242,10 @@ def _get_pickled_target_dict(
240242
f"Eager tasks are not supported in interactive mode. {entity.name} is an eager task."
241243
)
242244

245+
if isinstance(entity, ArrayNode):
246+
# extract WorkflowBase from ArrayNode
247+
entity = entity.target.workflow
248+
243249
if isinstance(entity, PythonTask):
244250
if isinstance(entity, (PythonAutoContainerTask, ArrayNodeMapTask)):
245251
if isinstance(entity, ArrayNodeMapTask):
@@ -1062,6 +1068,22 @@ def register_task(
10621068
domain=self.default_domain,
10631069
)
10641070

1071+
if self.interactive_mode_enabled:
1072+
md5_bytes, pickled_target_dict = _get_pickled_target_dict(entity)
1073+
if version is None:
1074+
version = self._version_from_hash(
1075+
md5_bytes,
1076+
serialization_settings,
1077+
entity.python_interface.default_inputs_as_kwargs,
1078+
*FlyteRemote._get_image_names(entity),
1079+
*FlyteRemote._get_pod_template_hash(entity),
1080+
)
1081+
1082+
serialization_settings.fast_serialization_settings = self._pickle_and_upload_entity(
1083+
entity,
1084+
pickled_target_dict,
1085+
)
1086+
10651087
ident = run_sync(self._serialize_and_register, entity=entity, settings=serialization_settings, version=version)
10661088

10671089
ft = self.fetch_task(
@@ -1097,6 +1119,22 @@ def register_workflow(
10971119
domain=self.default_domain,
10981120
)
10991121

1122+
if self.interactive_mode_enabled:
1123+
md5_bytes, pickled_target_dict = _get_pickled_target_dict(entity)
1124+
if version is None:
1125+
version = self._version_from_hash(
1126+
md5_bytes,
1127+
serialization_settings,
1128+
entity.python_interface.default_inputs_as_kwargs,
1129+
*FlyteRemote._get_image_names(entity),
1130+
*FlyteRemote._get_pod_template_hash(entity),
1131+
)
1132+
1133+
serialization_settings.fast_serialization_settings = self._pickle_and_upload_entity(
1134+
entity,
1135+
pickled_target_dict,
1136+
)
1137+
11001138
version, _ = self._resolve_version(version, entity, serialization_settings)
11011139

11021140
ident = run_sync(
@@ -1469,32 +1507,25 @@ def register_launch_plan(
14691507

14701508
version, _ = self._resolve_version(version, entity, serialization_settings)
14711509

1472-
if self._wf_exists(
1510+
if not self._wf_exists(
14731511
name=entity.workflow.name,
14741512
version=version,
14751513
project=serialization_settings.project,
14761514
domain=serialization_settings.domain,
14771515
):
1478-
# Underlying workflow, exists, only register the launch plan itself
1479-
launch_plan_model = get_serializable(
1480-
OrderedDict(), settings=serialization_settings, entity=entity, options=options
1481-
)
1482-
ident = self.raw_register(
1483-
launch_plan_model, serialization_settings, version, create_default_launchplan=False
1484-
)
1485-
if ident is None:
1486-
raise ValueError("Failed to register launch plan, identifier returned was empty...")
1487-
else:
1488-
# Register the launch and everything under it
1489-
ident = run_sync(
1490-
self._serialize_and_register,
1491-
entity,
1492-
serialization_settings,
1493-
version,
1494-
options,
1495-
False,
1516+
# If workflow doesn't exist, register it first
1517+
self.register_workflow(
1518+
entity.workflow, serialization_settings, version, default_launch_plan=False, options=options
14961519
)
14971520

1521+
# The underlying workflow exists at this point; only register the launch plan itself
1522+
launch_plan_model = get_serializable(
1523+
OrderedDict(), settings=serialization_settings, entity=entity, options=options
1524+
)
1525+
ident = self.raw_register(launch_plan_model, serialization_settings, version, create_default_launchplan=False)
1526+
if ident is None:
1527+
raise ValueError("Failed to register launch plan, identifier returned was empty...")
1528+
14981529
flp = self.fetch_launch_plan(ident.project, ident.domain, ident.name, ident.version)
14991530
flp.python_interface = entity.python_interface
15001531
return flp
@@ -2196,17 +2227,14 @@ def execute_local_task(
21962227
domain=domain or self._default_domain,
21972228
version=version,
21982229
)
2199-
version, pickled_target_dict = self._resolve_version(version, entity, ss)
2230+
version, _ = self._resolve_version(version, entity, ss)
22002231

22012232
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
22022233
resolved_identifiers_dict = asdict(resolved_identifiers)
22032234
try:
22042235
flyte_task: FlyteTask = self.fetch_task(**resolved_identifiers_dict)
22052236
flyte_task.python_interface = entity.python_interface
22062237
except FlyteEntityNotExistException:
2207-
if self.interactive_mode_enabled:
2208-
ss.fast_serialization_settings = self._pickle_and_upload_entity(entity, pickled_target_dict)
2209-
22102238
flyte_task: FlyteTask = self.register_task(entity, ss, version)
22112239

22122240
return self.execute(
@@ -2280,16 +2308,8 @@ def execute_local_workflow(
22802308
domain=domain or self._default_domain,
22812309
version=version,
22822310
)
2283-
pickled_target_dict = None
22842311
if version is None and self.interactive_mode_enabled:
2285-
md5_bytes, pickled_target_dict = _get_pickled_target_dict(entity)
2286-
version = self._version_from_hash(
2287-
md5_bytes,
2288-
ss,
2289-
entity.python_interface.default_inputs_as_kwargs,
2290-
*FlyteRemote._get_image_names(entity),
2291-
*FlyteRemote._get_pod_template_hash(entity),
2292-
)
2312+
version, _ = self._resolve_version(version, entity, ss)
22932313

22942314
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
22952315
resolved_identifiers_dict = asdict(resolved_identifiers)
@@ -2300,9 +2320,12 @@ def execute_local_workflow(
23002320
self.fetch_workflow(**resolved_identifiers_dict)
23012321
except FlyteEntityNotExistException:
23022322
logger.info("Registering workflow because it wasn't found in Flyte Admin.")
2303-
if self.interactive_mode_enabled:
2304-
ss.fast_serialization_settings = self._pickle_and_upload_entity(entity, pickled_target_dict)
2305-
self.register_workflow(entity, ss, version=version, options=options)
2323+
self.register_workflow(
2324+
entity,
2325+
ss,
2326+
version=version,
2327+
options=options,
2328+
)
23062329

23072330
try:
23082331
flyte_lp = self.fetch_launch_plan(**resolved_identifiers_dict)

0 commit comments

Comments
 (0)