4040from flytekit .configuration .file import ConfigFile
4141from flytekit .constants import CopyFileDetection
4242from flytekit .core import constants , utils
43+ from flytekit .core .array_node import ArrayNode
4344from flytekit .core .array_node_map_task import ArrayNodeMapTask
4445from flytekit .core .artifact import Artifact
4546from flytekit .core .base_task import PythonTask
@@ -211,6 +212,7 @@ def _get_git_repo_url(source_path: str):
211212 return ""
212213
213214
215+ @functools .lru_cache
214216def _get_pickled_target_dict (
215217 root_entity : typing .Union [WorkflowBase , PythonTask ],
216218) -> typing .Tuple [bytes , PickledEntity ]:
@@ -240,6 +242,10 @@ def _get_pickled_target_dict(
240242 f"Eager tasks are not supported in interactive mode. { entity .name } is an eager task."
241243 )
242244
245+ if isinstance (entity , ArrayNode ):
246+ # extract WorkflowBase from ArrayNode
247+ entity = entity .target .workflow
248+
243249 if isinstance (entity , PythonTask ):
244250 if isinstance (entity , (PythonAutoContainerTask , ArrayNodeMapTask )):
245251 if isinstance (entity , ArrayNodeMapTask ):
@@ -1062,6 +1068,22 @@ def register_task(
10621068 domain = self .default_domain ,
10631069 )
10641070
1071+ if self .interactive_mode_enabled :
1072+ md5_bytes , pickled_target_dict = _get_pickled_target_dict (entity )
1073+ if version is None :
1074+ version = self ._version_from_hash (
1075+ md5_bytes ,
1076+ serialization_settings ,
1077+ entity .python_interface .default_inputs_as_kwargs ,
1078+ * FlyteRemote ._get_image_names (entity ),
1079+ * FlyteRemote ._get_pod_template_hash (entity ),
1080+ )
1081+
1082+ serialization_settings .fast_serialization_settings = self ._pickle_and_upload_entity (
1083+ entity ,
1084+ pickled_target_dict ,
1085+ )
1086+
10651087 ident = run_sync (self ._serialize_and_register , entity = entity , settings = serialization_settings , version = version )
10661088
10671089 ft = self .fetch_task (
@@ -1097,6 +1119,22 @@ def register_workflow(
10971119 domain = self .default_domain ,
10981120 )
10991121
1122+ if self .interactive_mode_enabled :
1123+ md5_bytes , pickled_target_dict = _get_pickled_target_dict (entity )
1124+ if version is None :
1125+ version = self ._version_from_hash (
1126+ md5_bytes ,
1127+ serialization_settings ,
1128+ entity .python_interface .default_inputs_as_kwargs ,
1129+ * FlyteRemote ._get_image_names (entity ),
1130+ * FlyteRemote ._get_pod_template_hash (entity ),
1131+ )
1132+
1133+ serialization_settings .fast_serialization_settings = self ._pickle_and_upload_entity (
1134+ entity ,
1135+ pickled_target_dict ,
1136+ )
1137+
11001138 version , _ = self ._resolve_version (version , entity , serialization_settings )
11011139
11021140 ident = run_sync (
@@ -1469,32 +1507,25 @@ def register_launch_plan(
14691507
14701508 version , _ = self ._resolve_version (version , entity , serialization_settings )
14711509
1472- if self ._wf_exists (
1510+ if not self ._wf_exists (
14731511 name = entity .workflow .name ,
14741512 version = version ,
14751513 project = serialization_settings .project ,
14761514 domain = serialization_settings .domain ,
14771515 ):
1478- # Underlying workflow, exists, only register the launch plan itself
1479- launch_plan_model = get_serializable (
1480- OrderedDict (), settings = serialization_settings , entity = entity , options = options
1481- )
1482- ident = self .raw_register (
1483- launch_plan_model , serialization_settings , version , create_default_launchplan = False
1484- )
1485- if ident is None :
1486- raise ValueError ("Failed to register launch plan, identifier returned was empty..." )
1487- else :
1488- # Register the launch and everything under it
1489- ident = run_sync (
1490- self ._serialize_and_register ,
1491- entity ,
1492- serialization_settings ,
1493- version ,
1494- options ,
1495- False ,
1516+ # If workflow doesn't exist, register it first
1517+ self .register_workflow (
1518+ entity .workflow , serialization_settings , version , default_launch_plan = False , options = options
14961519 )
14971520
1521+ # Underlying workflow, exists, only register the launch plan itself
1522+ launch_plan_model = get_serializable (
1523+ OrderedDict (), settings = serialization_settings , entity = entity , options = options
1524+ )
1525+ ident = self .raw_register (launch_plan_model , serialization_settings , version , create_default_launchplan = False )
1526+ if ident is None :
1527+ raise ValueError ("Failed to register launch plan, identifier returned was empty..." )
1528+
14981529 flp = self .fetch_launch_plan (ident .project , ident .domain , ident .name , ident .version )
14991530 flp .python_interface = entity .python_interface
15001531 return flp
@@ -2196,17 +2227,14 @@ def execute_local_task(
21962227 domain = domain or self ._default_domain ,
21972228 version = version ,
21982229 )
2199- version , pickled_target_dict = self ._resolve_version (version , entity , ss )
2230+ version , _ = self ._resolve_version (version , entity , ss )
22002231
22012232 resolved_identifiers = self ._resolve_identifier_kwargs (entity , project , domain , name , version )
22022233 resolved_identifiers_dict = asdict (resolved_identifiers )
22032234 try :
22042235 flyte_task : FlyteTask = self .fetch_task (** resolved_identifiers_dict )
22052236 flyte_task .python_interface = entity .python_interface
22062237 except FlyteEntityNotExistException :
2207- if self .interactive_mode_enabled :
2208- ss .fast_serialization_settings = self ._pickle_and_upload_entity (entity , pickled_target_dict )
2209-
22102238 flyte_task : FlyteTask = self .register_task (entity , ss , version )
22112239
22122240 return self .execute (
@@ -2280,16 +2308,8 @@ def execute_local_workflow(
22802308 domain = domain or self ._default_domain ,
22812309 version = version ,
22822310 )
2283- pickled_target_dict = None
22842311 if version is None and self .interactive_mode_enabled :
2285- md5_bytes , pickled_target_dict = _get_pickled_target_dict (entity )
2286- version = self ._version_from_hash (
2287- md5_bytes ,
2288- ss ,
2289- entity .python_interface .default_inputs_as_kwargs ,
2290- * FlyteRemote ._get_image_names (entity ),
2291- * FlyteRemote ._get_pod_template_hash (entity ),
2292- )
2312+ version , _ = self ._resolve_version (version , entity , ss )
22932313
22942314 resolved_identifiers = self ._resolve_identifier_kwargs (entity , project , domain , name , version )
22952315 resolved_identifiers_dict = asdict (resolved_identifiers )
@@ -2300,9 +2320,12 @@ def execute_local_workflow(
23002320 self .fetch_workflow (** resolved_identifiers_dict )
23012321 except FlyteEntityNotExistException :
23022322 logger .info ("Registering workflow because it wasn't found in Flyte Admin." )
2303- if self .interactive_mode_enabled :
2304- ss .fast_serialization_settings = self ._pickle_and_upload_entity (entity , pickled_target_dict )
2305- self .register_workflow (entity , ss , version = version , options = options )
2323+ self .register_workflow (
2324+ entity ,
2325+ ss ,
2326+ version = version ,
2327+ options = options ,
2328+ )
23062329
23072330 try :
23082331 flyte_lp = self .fetch_launch_plan (** resolved_identifiers_dict )
0 commit comments