@@ -148,6 +148,7 @@ type ExecutionConfig struct {
148
148
149
149
// DAGExecution custom properties
150
150
IterationCount * int // Number of iterations for an iterator DAG.
151
+ TotalDagTasks * int // Number of tasks inside the DAG
151
152
}
152
153
153
154
// InputArtifact is a wrapper around an MLMD artifact used as component inputs.
@@ -526,6 +527,7 @@ const (
526
527
keyParentDagID = "parent_dag_id" // Parent DAG Execution ID.
527
528
keyIterationIndex = "iteration_index"
528
529
keyIterationCount = "iteration_count"
530
+ keyTotalDagTasks = "total_dag_tasks"
529
531
)
530
532
531
533
// CreateExecution creates a new MLMD execution under the specified Pipeline.
@@ -620,6 +622,9 @@ func (c *Client) CreateExecution(ctx context.Context, pipeline *Pipeline, config
620
622
}
621
623
e .CustomProperties [keyArtifactProducerTask ] = StringValue (string (b ))
622
624
}
625
+ if config .TotalDagTasks != nil {
626
+ e .CustomProperties [keyTotalDagTasks ] = intValue (int64 (* config .TotalDagTasks ))
627
+ }
623
628
624
629
req := & pb.PutExecutionRequest {
625
630
Execution : e ,
@@ -690,11 +695,13 @@ func (c *Client) UpdateDAGExecutionsState(ctx context.Context, dag *DAG, pipelin
690
695
if err != nil {
691
696
return err
692
697
}
698
+
699
+ totalDagTasks := dag .Execution .execution .CustomProperties ["total_dag_tasks" ].GetIntValue ()
700
+
693
701
glog .V (4 ).Infof ("tasks: %v" , tasks )
694
702
glog .V (4 ).Infof ("Checking Tasks' State" )
695
703
completedTasks := 0
696
704
failedTasks := 0
697
- totalTasks := len (tasks )
698
705
for _ , task := range tasks {
699
706
taskState := task .GetExecution ().LastKnownState .String ()
700
707
glog .V (4 ).Infof ("task: %s" , task .TaskName ())
@@ -712,10 +719,10 @@ func (c *Client) UpdateDAGExecutionsState(ctx context.Context, dag *DAG, pipelin
712
719
}
713
720
glog .V (4 ).Infof ("completedTasks: %d" , completedTasks )
714
721
glog .V (4 ).Infof ("failedTasks: %d" , failedTasks )
715
- glog .V (4 ).Infof ("totalTasks: %d" , totalTasks )
722
+ glog .V (4 ).Infof ("totalTasks: %d" , totalDagTasks )
716
723
717
724
glog .Infof ("Attempting to update DAG state" )
718
- if completedTasks == totalTasks {
725
+ if completedTasks == int ( totalDagTasks ) {
719
726
c .PutDAGExecutionState (ctx , dag .Execution .GetID (), pb .Execution_COMPLETE )
720
727
} else if failedTasks > 0 {
721
728
c .PutDAGExecutionState (ctx , dag .Execution .GetID (), pb .Execution_FAILED )
0 commit comments