Skip to content

Commit eb684d7

Browse files
authored
Recursively handle subworkflow in dynamic (#260)
* Recursively handle subworkflow in dynamic Signed-off-by: Hongxin Liang <[email protected]> * Sub in Sub IT Signed-off-by: Hongxin Liang <[email protected]> * Fix subworkflow collecting Signed-off-by: Hongxin Liang <[email protected]> --------- Signed-off-by: Hongxin Liang <[email protected]>
1 parent c892ceb commit eb684d7

File tree

3 files changed

+120
-36
lines changed

3 files changed

+120
-36
lines changed

flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,19 @@ public Output run(SdkWorkflowBuilder builder, Input input) {
8787
SdkTypes.nulls(),
8888
SdkTypes.nulls())
8989
.withUpstreamNode(hello));
90+
// subworkflow that contains another subworkflow
91+
SdkNode<WelcomeWorkflow.Output> greet =
92+
builder.apply(
93+
"greet",
94+
new SubWorkflow().withUpstreamNode(world),
95+
WelcomeWorkflow.Input.create(SdkBindingDataFactory.of("greet")));
9096
@Var SdkBindingData<Long> prev = SdkBindingDataFactory.of(0);
9197
@Var SdkBindingData<Long> value = SdkBindingDataFactory.of(1);
9298
for (int i = 2; i <= input.n().get(); i++) {
9399
SdkBindingData<Long> next =
94100
builder
95101
.apply(
96-
"fib-" + i, new SumTask().withUpstreamNode(world), SumInput.create(value, prev))
102+
"fib-" + i, new SumTask().withUpstreamNode(greet), SumInput.create(value, prev))
97103
.getOutputs();
98104
prev = value;
99105
value = next;

jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,16 @@ && checkCycles(subWorkflowId, allWorkflows, beingVisited, visited))) {
343343
}
344344

345345
@VisibleForTesting
346+
static Map<WorkflowIdentifier, WorkflowTemplate> collectSubWorkflows(
347+
List<Node> nodes, Map<WorkflowIdentifier, WorkflowTemplate> allWorkflows) {
348+
return collectSubWorkflows(nodes, allWorkflows, Function.identity());
349+
}
350+
346351
public static Map<WorkflowIdentifier, WorkflowTemplate> collectSubWorkflows(
347-
List<Node> rewrittenNodes, Map<WorkflowIdentifier, WorkflowTemplate> allWorkflows) {
352+
List<Node> nodes,
353+
Map<WorkflowIdentifier, WorkflowTemplate> allWorkflows,
354+
Function<List<Node>, List<Node>> nodesRewriter) {
355+
List<Node> rewrittenNodes = nodesRewriter.apply(nodes);
348356
return collectSubWorkflowIds(rewrittenNodes).stream()
349357
// all identifiers should be rewritten at this point
350358
.map(
@@ -366,7 +374,7 @@ public static Map<WorkflowIdentifier, WorkflowTemplate> collectSubWorkflows(
366374
}
367375

368376
Map<WorkflowIdentifier, WorkflowTemplate> nestedSubWorkflows =
369-
collectSubWorkflows(subWorkflow.nodes(), allWorkflows);
377+
collectSubWorkflows(subWorkflow.nodes(), allWorkflows, nodesRewriter);
370378

371379
return Stream.concat(
372380
Stream.of(Maps.immutableEntry(workflowId, subWorkflow)),
@@ -376,10 +384,10 @@ public static Map<WorkflowIdentifier, WorkflowTemplate> collectSubWorkflows(
376384
}
377385

378386
public static Map<TaskIdentifier, TaskTemplate> collectDynamicWorkflowTasks(
379-
List<Node> rewrittenNodes,
387+
List<Node> nodes,
380388
Map<TaskIdentifier, TaskTemplate> allTasks,
381389
Function<TaskIdentifier, TaskTemplate> remoteTaskTemplateFetcher) {
382-
return collectTaskIds(rewrittenNodes).stream()
390+
return collectTaskIds(nodes).stream()
383391
// all identifiers should be rewritten at this point
384392
.map(
385393
taskId ->

jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java

Lines changed: 101 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,13 @@
2525
import com.google.common.collect.ImmutableMap;
2626
import com.google.common.collect.Maps;
2727
import java.util.Collection;
28+
import java.util.HashMap;
2829
import java.util.List;
2930
import java.util.Map;
3031
import java.util.concurrent.Callable;
3132
import java.util.concurrent.ExecutorService;
3233
import java.util.concurrent.ForkJoinPool;
34+
import java.util.function.Function;
3335
import java.util.stream.Collectors;
3436
import org.flyte.api.v1.Binding;
3537
import org.flyte.api.v1.BindingData;
@@ -197,12 +199,12 @@ private void execute() {
197199
}
198200
}
199201

200-
static DynamicJobSpec rewrite(
202+
private static DynamicJobSpec rewrite(
201203
Config config,
202204
ExecutionConfig executionConfig,
203205
DynamicJobSpec spec,
204-
Map<TaskIdentifier, TaskTemplate> taskTemplates,
205-
Map<WorkflowIdentifier, WorkflowTemplate> workflowTemplates) {
206+
Map<TaskIdentifier, TaskTemplate> allTaskTemplates,
207+
Map<WorkflowIdentifier, WorkflowTemplate> allWorkflowTemplates) {
206208

207209
try (FlyteAdminClient flyteAdminClient =
208210
FlyteAdminClient.create(config.platformUrl(), config.platformInsecure(), null)) {
@@ -215,58 +217,126 @@ static DynamicJobSpec rewrite(
215217
.adminClient(flyteAdminClient)
216218
.build()
217219
.visitor();
220+
Function<List<Node>, List<Node>> nodesRewriter =
221+
nodes -> nodes.stream().map(workflowNodeVisitor::visitNode).collect(toUnmodifiableList());
218222

219-
List<Node> rewrittenNodes =
220-
spec.nodes().stream().map(workflowNodeVisitor::visitNode).collect(toUnmodifiableList());
221-
222-
Map<WorkflowIdentifier, WorkflowTemplate> usedSubWorkflows =
223-
ProjectClosure.collectSubWorkflows(rewrittenNodes, workflowTemplates);
224-
225-
Map<TaskIdentifier, TaskTemplate> usedTaskTemplates =
226-
ProjectClosure.collectDynamicWorkflowTasks(
227-
rewrittenNodes, taskTemplates, id -> fetchTaskTemplate(flyteAdminClient, id));
228-
229-
// FIXME one sub-workflow can use more sub-workflows, we should recursively collect used tasks
230-
// and workflows
223+
Map<WorkflowIdentifier, WorkflowTemplate> allUsedSubWorkflows =
224+
collectAllUsedSubWorkflows(
225+
spec.nodes(), allWorkflowTemplates, workflowNodeVisitor, nodesRewriter);
231226

232-
Map<WorkflowIdentifier, WorkflowTemplate> rewrittenUsedSubWorkflows =
233-
mapValues(usedSubWorkflows, workflowNodeVisitor::visitWorkflowTemplate);
227+
Map<TaskIdentifier, TaskTemplate> allUsedTaskTemplates = new HashMap<>();
228+
List<Node> rewrittenNodes =
229+
collectAllUsedTaskTemplates(
230+
spec,
231+
allTaskTemplates,
232+
nodesRewriter,
233+
allUsedTaskTemplates,
234+
flyteAdminClient,
235+
allUsedSubWorkflows);
234236

235237
return spec.toBuilder()
236238
.nodes(rewrittenNodes)
237239
.subWorkflows(
238240
ImmutableMap.<WorkflowIdentifier, WorkflowTemplate>builder()
239241
.putAll(spec.subWorkflows())
240-
.putAll(rewrittenUsedSubWorkflows)
242+
.putAll(allUsedSubWorkflows)
241243
.build())
242244
.tasks(
243245
ImmutableMap.<TaskIdentifier, TaskTemplate>builder()
244246
.putAll(spec.tasks())
245-
.putAll(usedTaskTemplates)
247+
.putAll(allUsedTaskTemplates)
246248
.build())
247249
.build();
248250
}
249251
}
250252

253+
private static List<Node> collectAllUsedTaskTemplates(
254+
DynamicJobSpec spec,
255+
Map<TaskIdentifier, TaskTemplate> allTaskTemplates,
256+
Function<List<Node>, List<Node>> nodesRewriter,
257+
Map<TaskIdentifier, TaskTemplate> allUsedTaskTemplates,
258+
FlyteAdminClient flyteAdminClient,
259+
Map<WorkflowIdentifier, WorkflowTemplate> allUsedSubWorkflows) {
260+
261+
Map<TaskIdentifier, TaskTemplate> cache = new HashMap<>();
262+
263+
// collect directly used task templates
264+
List<Node> rewrittenNodes =
265+
collectTaskTemplates(
266+
spec.nodes(),
267+
nodesRewriter,
268+
allUsedTaskTemplates,
269+
allTaskTemplates,
270+
flyteAdminClient,
271+
cache);
272+
273+
// collect task templates used by subworkflows
274+
allUsedSubWorkflows
275+
.values()
276+
.forEach(
277+
workflowTemplate ->
278+
collectTaskTemplates(
279+
workflowTemplate.nodes(),
280+
nodesRewriter,
281+
allUsedTaskTemplates,
282+
allTaskTemplates,
283+
flyteAdminClient,
284+
cache));
285+
286+
return rewrittenNodes;
287+
}
288+
289+
private static Map<WorkflowIdentifier, WorkflowTemplate> collectAllUsedSubWorkflows(
290+
List<Node> nodes,
291+
Map<WorkflowIdentifier, WorkflowTemplate> workflowTemplates,
292+
WorkflowNodeVisitor workflowNodeVisitor,
293+
Function<List<Node>, List<Node>> nodesRewriter) {
294+
295+
Map<WorkflowIdentifier, WorkflowTemplate> allUsedSubWorkflows =
296+
ProjectClosure.collectSubWorkflows(nodes, workflowTemplates, nodesRewriter);
297+
return mapValues(allUsedSubWorkflows, workflowNodeVisitor::visitWorkflowTemplate);
298+
}
299+
300+
private static List<Node> collectTaskTemplates(
301+
List<Node> nodes,
302+
Function<List<Node>, List<Node>> nodesRewriter,
303+
Map<TaskIdentifier, TaskTemplate> allUsedTaskTemplates,
304+
Map<TaskIdentifier, TaskTemplate> allTaskTemplates,
305+
FlyteAdminClient flyteAdminClient,
306+
Map<TaskIdentifier, TaskTemplate> cache) {
307+
308+
List<Node> rewrittenNodes = nodesRewriter.apply(nodes);
309+
310+
Map<TaskIdentifier, TaskTemplate> usedTaskTemplates =
311+
ProjectClosure.collectDynamicWorkflowTasks(
312+
rewrittenNodes, allTaskTemplates, id -> fetchTaskTemplate(flyteAdminClient, id, cache));
313+
allUsedTaskTemplates.putAll(usedTaskTemplates);
314+
315+
return rewrittenNodes;
316+
}
317+
251318
// note that there are cases we are making an unnecessary network call because we might have
252319
// already got the task template when resolving the latest task version, but since it is also
253320
// possible that user has provided a version for a remote task, and in that case we would not need
254321
// to resolve the latest version, so we need to make this call;
255322
// we accept the additional cost because it should be rare to have remote tasks in a dynamic
256323
// workflow
257324
private static TaskTemplate fetchTaskTemplate(
258-
FlyteAdminClient flyteAdminClient, TaskIdentifier id) {
259-
LOG.info("fetching task template remotely for {}", id);
260-
261-
TaskTemplate taskTemplate =
262-
flyteAdminClient.fetchLatestTaskTemplate(
263-
NamedEntityIdentifier.builder()
264-
.domain(id.domain())
265-
.project(id.project())
266-
.name(id.name())
267-
.build());
268-
269-
return taskTemplate;
325+
FlyteAdminClient flyteAdminClient,
326+
TaskIdentifier id,
327+
Map<TaskIdentifier, TaskTemplate> cache) {
328+
return cache.computeIfAbsent(
329+
id,
330+
taskIdentifier -> {
331+
LOG.info("fetching task template remotely for {}", id);
332+
333+
return flyteAdminClient.fetchLatestTaskTemplate(
334+
NamedEntityIdentifier.builder()
335+
.domain(id.domain())
336+
.project(id.project())
337+
.name(id.name())
338+
.build());
339+
});
270340
}
271341

272342
private static DynamicWorkflowTask getDynamicWorkflowTask(String name) {

0 commit comments

Comments
 (0)