2525import com .google .common .collect .ImmutableMap ;
2626import com .google .common .collect .Maps ;
2727import java .util .Collection ;
28+ import java .util .HashMap ;
2829import java .util .List ;
2930import java .util .Map ;
3031import java .util .concurrent .Callable ;
3132import java .util .concurrent .ExecutorService ;
3233import java .util .concurrent .ForkJoinPool ;
34+ import java .util .function .Function ;
3335import java .util .stream .Collectors ;
3436import org .flyte .api .v1 .Binding ;
3537import org .flyte .api .v1 .BindingData ;
@@ -197,12 +199,12 @@ private void execute() {
197199 }
198200 }
199201
200- static DynamicJobSpec rewrite (
202+ private static DynamicJobSpec rewrite (
201203 Config config ,
202204 ExecutionConfig executionConfig ,
203205 DynamicJobSpec spec ,
204- Map <TaskIdentifier , TaskTemplate > taskTemplates ,
205- Map <WorkflowIdentifier , WorkflowTemplate > workflowTemplates ) {
206+ Map <TaskIdentifier , TaskTemplate > allTaskTemplates ,
207+ Map <WorkflowIdentifier , WorkflowTemplate > allWorkflowTemplates ) {
206208
207209 try (FlyteAdminClient flyteAdminClient =
208210 FlyteAdminClient .create (config .platformUrl (), config .platformInsecure (), null )) {
@@ -215,58 +217,126 @@ static DynamicJobSpec rewrite(
215217 .adminClient (flyteAdminClient )
216218 .build ()
217219 .visitor ();
220+ Function <List <Node >, List <Node >> nodesRewriter =
221+ nodes -> nodes .stream ().map (workflowNodeVisitor ::visitNode ).collect (toUnmodifiableList ());
218222
219- List <Node > rewrittenNodes =
220- spec .nodes ().stream ().map (workflowNodeVisitor ::visitNode ).collect (toUnmodifiableList ());
221-
222- Map <WorkflowIdentifier , WorkflowTemplate > usedSubWorkflows =
223- ProjectClosure .collectSubWorkflows (rewrittenNodes , workflowTemplates );
224-
225- Map <TaskIdentifier , TaskTemplate > usedTaskTemplates =
226- ProjectClosure .collectDynamicWorkflowTasks (
227- rewrittenNodes , taskTemplates , id -> fetchTaskTemplate (flyteAdminClient , id ));
228-
229- // FIXME one sub-workflow can use more sub-workflows, we should recursively collect used tasks
230- // and workflows
223+ Map <WorkflowIdentifier , WorkflowTemplate > allUsedSubWorkflows =
224+ collectAllUsedSubWorkflows (
225+ spec .nodes (), allWorkflowTemplates , workflowNodeVisitor , nodesRewriter );
231226
232- Map <WorkflowIdentifier , WorkflowTemplate > rewrittenUsedSubWorkflows =
233- mapValues (usedSubWorkflows , workflowNodeVisitor ::visitWorkflowTemplate );
227+ Map <TaskIdentifier , TaskTemplate > allUsedTaskTemplates = new HashMap <>();
228+ List <Node > rewrittenNodes =
229+ collectAllUsedTaskTemplates (
230+ spec ,
231+ allTaskTemplates ,
232+ nodesRewriter ,
233+ allUsedTaskTemplates ,
234+ flyteAdminClient ,
235+ allUsedSubWorkflows );
234236
235237 return spec .toBuilder ()
236238 .nodes (rewrittenNodes )
237239 .subWorkflows (
238240 ImmutableMap .<WorkflowIdentifier , WorkflowTemplate >builder ()
239241 .putAll (spec .subWorkflows ())
240- .putAll (rewrittenUsedSubWorkflows )
242+ .putAll (allUsedSubWorkflows )
241243 .build ())
242244 .tasks (
243245 ImmutableMap .<TaskIdentifier , TaskTemplate >builder ()
244246 .putAll (spec .tasks ())
245- .putAll (usedTaskTemplates )
247+ .putAll (allUsedTaskTemplates )
246248 .build ())
247249 .build ();
248250 }
249251 }
250252
253+ private static List <Node > collectAllUsedTaskTemplates (
254+ DynamicJobSpec spec ,
255+ Map <TaskIdentifier , TaskTemplate > allTaskTemplates ,
256+ Function <List <Node >, List <Node >> nodesRewriter ,
257+ Map <TaskIdentifier , TaskTemplate > allUsedTaskTemplates ,
258+ FlyteAdminClient flyteAdminClient ,
259+ Map <WorkflowIdentifier , WorkflowTemplate > allUsedSubWorkflows ) {
260+
261+ Map <TaskIdentifier , TaskTemplate > cache = new HashMap <>();
262+
263+ // collect directly used task templates
264+ List <Node > rewrittenNodes =
265+ collectTaskTemplates (
266+ spec .nodes (),
267+ nodesRewriter ,
268+ allUsedTaskTemplates ,
269+ allTaskTemplates ,
270+ flyteAdminClient ,
271+ cache );
272+
273+ // collect task templates used by subworkflows
274+ allUsedSubWorkflows
275+ .values ()
276+ .forEach (
277+ workflowTemplate ->
278+ collectTaskTemplates (
279+ workflowTemplate .nodes (),
280+ nodesRewriter ,
281+ allUsedTaskTemplates ,
282+ allTaskTemplates ,
283+ flyteAdminClient ,
284+ cache ));
285+
286+ return rewrittenNodes ;
287+ }
288+
289+ private static Map <WorkflowIdentifier , WorkflowTemplate > collectAllUsedSubWorkflows (
290+ List <Node > nodes ,
291+ Map <WorkflowIdentifier , WorkflowTemplate > workflowTemplates ,
292+ WorkflowNodeVisitor workflowNodeVisitor ,
293+ Function <List <Node >, List <Node >> nodesRewriter ) {
294+
295+ Map <WorkflowIdentifier , WorkflowTemplate > allUsedSubWorkflows =
296+ ProjectClosure .collectSubWorkflows (nodes , workflowTemplates , nodesRewriter );
297+ return mapValues (allUsedSubWorkflows , workflowNodeVisitor ::visitWorkflowTemplate );
298+ }
299+
300+ private static List <Node > collectTaskTemplates (
301+ List <Node > nodes ,
302+ Function <List <Node >, List <Node >> nodesRewriter ,
303+ Map <TaskIdentifier , TaskTemplate > allUsedTaskTemplates ,
304+ Map <TaskIdentifier , TaskTemplate > allTaskTemplates ,
305+ FlyteAdminClient flyteAdminClient ,
306+ Map <TaskIdentifier , TaskTemplate > cache ) {
307+
308+ List <Node > rewrittenNodes = nodesRewriter .apply (nodes );
309+
310+ Map <TaskIdentifier , TaskTemplate > usedTaskTemplates =
311+ ProjectClosure .collectDynamicWorkflowTasks (
312+ rewrittenNodes , allTaskTemplates , id -> fetchTaskTemplate (flyteAdminClient , id , cache ));
313+ allUsedTaskTemplates .putAll (usedTaskTemplates );
314+
315+ return rewrittenNodes ;
316+ }
317+
251318 // note that there are cases we are making an unnecessary network call because we might have
252319 // already got the task template when resolving the latest task version, but since it is also
253320 // possible that user has provided a version for a remote task, and in that case we would not need
254321 // to resolve the latest version, so we need to make this call;
255322 // we accept the additional cost because it should be rare to have remote tasks in a dynamic
256323 // workflow
257324 private static TaskTemplate fetchTaskTemplate (
258- FlyteAdminClient flyteAdminClient , TaskIdentifier id ) {
259- LOG .info ("fetching task template remotely for {}" , id );
260-
261- TaskTemplate taskTemplate =
262- flyteAdminClient .fetchLatestTaskTemplate (
263- NamedEntityIdentifier .builder ()
264- .domain (id .domain ())
265- .project (id .project ())
266- .name (id .name ())
267- .build ());
268-
269- return taskTemplate ;
325+ FlyteAdminClient flyteAdminClient ,
326+ TaskIdentifier id ,
327+ Map <TaskIdentifier , TaskTemplate > cache ) {
328+ return cache .computeIfAbsent (
329+ id ,
330+ taskIdentifier -> {
331+ LOG .info ("fetching task template remotely for {}" , id );
332+
333+ return flyteAdminClient .fetchLatestTaskTemplate (
334+ NamedEntityIdentifier .builder ()
335+ .domain (id .domain ())
336+ .project (id .project ())
337+ .name (id .name ())
338+ .build ());
339+ });
270340 }
271341
272342 private static DynamicWorkflowTask getDynamicWorkflowTask (String name ) {
0 commit comments