Skip to content

Commit 7719b38

Browse files
fix(backend): fixes DAG status update to reflect completion of all tasks (kubeflow#11651)
Signed-off-by: VaniHaripriya <[email protected]>
1 parent cc1c435 commit 7719b38

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

backend/src/v2/driver/driver.go

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -979,6 +979,10 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
979979
ecfg.OutputArtifacts = opts.Component.GetDag().GetOutputs().GetArtifacts()
980980
glog.V(4).Info("outputArtifacts: ", ecfg.OutputArtifacts)
981981

982+
totalDagTasks := len(opts.Component.GetDag().GetTasks())
983+
ecfg.TotalDagTasks = &totalDagTasks
984+
glog.V(4).Info("totalDagTasks: ", *ecfg.TotalDagTasks)
985+
982986
if opts.Task.GetArtifactIterator() != nil {
983987
return execution, fmt.Errorf("ArtifactIterator is not implemented")
984988
}

backend/src/v2/metadata/client.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ type ExecutionConfig struct {
148148

149149
// DAGExecution custom properties
150150
IterationCount *int // Number of iterations for an iterator DAG.
151+
TotalDagTasks *int // Number of tasks inside the DAG
151152
}
152153

153154
// InputArtifact is a wrapper around an MLMD artifact used as component inputs.
@@ -526,6 +527,7 @@ const (
526527
keyParentDagID = "parent_dag_id" // Parent DAG Execution ID.
527528
keyIterationIndex = "iteration_index"
528529
keyIterationCount = "iteration_count"
530+
keyTotalDagTasks = "total_dag_tasks"
529531
)
530532

531533
// CreateExecution creates a new MLMD execution under the specified Pipeline.
@@ -620,6 +622,9 @@ func (c *Client) CreateExecution(ctx context.Context, pipeline *Pipeline, config
620622
}
621623
e.CustomProperties[keyArtifactProducerTask] = StringValue(string(b))
622624
}
625+
if config.TotalDagTasks != nil {
626+
e.CustomProperties[keyTotalDagTasks] = intValue(int64(*config.TotalDagTasks))
627+
}
623628

624629
req := &pb.PutExecutionRequest{
625630
Execution: e,
@@ -690,11 +695,13 @@ func (c *Client) UpdateDAGExecutionsState(ctx context.Context, dag *DAG, pipelin
690695
if err != nil {
691696
return err
692697
}
698+
699+
totalDagTasks := dag.Execution.execution.CustomProperties["total_dag_tasks"].GetIntValue()
700+
693701
glog.V(4).Infof("tasks: %v", tasks)
694702
glog.V(4).Infof("Checking Tasks' State")
695703
completedTasks := 0
696704
failedTasks := 0
697-
totalTasks := len(tasks)
698705
for _, task := range tasks {
699706
taskState := task.GetExecution().LastKnownState.String()
700707
glog.V(4).Infof("task: %s", task.TaskName())
@@ -712,10 +719,10 @@ func (c *Client) UpdateDAGExecutionsState(ctx context.Context, dag *DAG, pipelin
712719
}
713720
glog.V(4).Infof("completedTasks: %d", completedTasks)
714721
glog.V(4).Infof("failedTasks: %d", failedTasks)
715-
glog.V(4).Infof("totalTasks: %d", totalTasks)
722+
glog.V(4).Infof("totalTasks: %d", totalDagTasks)
716723

717724
glog.Infof("Attempting to update DAG state")
718-
if completedTasks == totalTasks {
725+
if completedTasks == int(totalDagTasks) {
719726
c.PutDAGExecutionState(ctx, dag.Execution.GetID(), pb.Execution_COMPLETE)
720727
} else if failedTasks > 0 {
721728
c.PutDAGExecutionState(ctx, dag.Execution.GetID(), pb.Execution_FAILED)

0 commit comments

Comments
 (0)