Skip to content

Commit 2ac5750

Browse files
authored
[V2] Add attempt and cache status (#7078)
* feat: add attempts and cache status to CR Signed-off-by: machichima <nary12321@gmail.com> * feat: get attempts/cache status from object Signed-off-by: machichima <nary12321@gmail.com> * feat: store attempts/cache status to CR and send through action event Signed-off-by: machichima <nary12321@gmail.com> * feat: consider retry count in generated name Signed-off-by: machichima <nary12321@gmail.com> * feat: get max attempt from task template Signed-off-by: machichima <nary12321@gmail.com> * feat: run service save attempts/cache status to DB Signed-off-by: machichima <nary12321@gmail.com> * build: make sandbox-build Signed-off-by: machichima <nary12321@gmail.com> --------- Signed-off-by: machichima <nary12321@gmail.com>
1 parent dfe433d commit 2ac5750

File tree

14 files changed

+228
-35
lines changed

14 files changed

+228
-35
lines changed

actions/k8s/client.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,7 @@ func (c *ActionsClient) GetState(ctx context.Context, actionID *common.ActionIde
189189
return taskAction.Status.StateJSON, nil
190190
}
191191

192-
// PutState updates the state JSON for a TaskAction.
193-
// attempt and status are accepted for future use (e.g. recording to RunService).
192+
// PutState updates the state JSON and latest attempt metadata for a TaskAction.
194193
func (c *ActionsClient) PutState(ctx context.Context, actionID *common.ActionIdentifier, attempt uint32, status *workflow.ActionStatus, stateJSON string) error {
195194
taskActionName := buildTaskActionName(actionID)
196195

@@ -210,6 +209,10 @@ func (c *ActionsClient) PutState(ctx context.Context, actionID *common.ActionIde
210209

211210
// Update state JSON
212211
taskAction.Status.StateJSON = stateJSON
212+
if status != nil {
213+
taskAction.Status.Attempts = status.GetAttempts()
214+
taskAction.Status.CacheStatus = status.GetCacheStatus()
215+
}
213216

214217
// Update status subresource
215218
if err := c.k8sClient.Status().Update(ctx, taskAction); err != nil {
@@ -499,7 +502,9 @@ func (c *ActionsClient) notifyRunService(ctx context.Context, taskAction *execut
499502
statusReq := &workflow.UpdateActionStatusRequest{
500503
ActionId: update.ActionID,
501504
Status: &workflow.ActionStatus{
502-
Phase: update.Phase,
505+
Phase: update.Phase,
506+
Attempts: taskAction.Status.Attempts,
507+
CacheStatus: taskAction.Status.CacheStatus,
503508
},
504509
}
505510
if _, err := c.runClient.UpdateActionStatus(ctx, connect.NewRequest(statusReq)); err != nil {
@@ -548,15 +553,14 @@ func GetPhaseFromConditions(taskAction *executorv1.TaskAction) common.ActionPhas
548553
}
549554

550555
// buildTaskActionName generates a Kubernetes-compliant name for the TaskAction.
551-
// For root actions (where action name == run name), the name is <run-id>-a0-0.
552-
// For child actions, the name is <run-id>-<action-id>-0.
553-
// The trailing "0" is the attempt number (0-indexed; hardcoded until retry support is added).
556+
// For root actions (where action name == run name), the name is <run-id>-a0.
557+
// For child actions, the name is <run-id>-<action-id>.
554558
func buildTaskActionName(actionID *common.ActionIdentifier) string {
555559
isRoot := actionID.Name == actionID.Run.Name
556560
if isRoot {
557-
return fmt.Sprintf("%s-a0-0", actionID.Run.Name)
561+
return fmt.Sprintf("%s-a0", actionID.Run.Name)
558562
}
559-
return fmt.Sprintf("%s-%s-0", actionID.Run.Name, actionID.Name)
563+
return fmt.Sprintf("%s-%s", actionID.Run.Name, actionID.Name)
560564
}
561565

562566
// buildNamespace returns the Kubernetes namespace for a run: "<project>-<domain>".

actions/k8s/client_test.go

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/flyteorg/flyte/v2/flytestdlib/fastcheck"
1515
"github.com/flyteorg/flyte/v2/flytestdlib/promutils"
1616
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common"
17+
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
1718
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow"
1819
runmocks "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect/mocks"
1920
)
@@ -130,6 +131,32 @@ func TestNotifyRunService_NilFilter(t *testing.T) {
130131
mockClient.AssertNumberOfCalls(t, "RecordAction", 2)
131132
}
132133

134+
func TestNotifyRunService_UpdateActionStatusIncludesAttemptsAndCacheStatus(t *testing.T) {
135+
ctx := context.Background()
136+
137+
mockClient := runmocks.NewInternalRunServiceClient(t)
138+
c := &ActionsClient{
139+
runClient: mockClient,
140+
subscribers: make(map[string]map[chan *ActionUpdate]struct{}),
141+
}
142+
143+
ta, update := newTestActionUpdate("action-4")
144+
ta.Status.Attempts = 3
145+
ta.Status.CacheStatus = core.CatalogCacheStatus_CACHE_HIT
146+
update.Phase = common.ActionPhase_ACTION_PHASE_SUCCEEDED
147+
148+
mockClient.On("UpdateActionStatus", mock.Anything, mock.MatchedBy(func(req *connect.Request[workflow.UpdateActionStatusRequest]) bool {
149+
status := req.Msg.GetStatus()
150+
return status.GetPhase() == common.ActionPhase_ACTION_PHASE_SUCCEEDED &&
151+
status.GetAttempts() == 3 &&
152+
status.GetCacheStatus() == core.CatalogCacheStatus_CACHE_HIT
153+
})).Return(&connect.Response[workflow.UpdateActionStatusResponse]{}, nil).Once()
154+
155+
c.notifyRunService(ctx, ta, update, watch.Modified)
156+
157+
mockClient.AssertNumberOfCalls(t, "UpdateActionStatus", 1)
158+
}
159+
133160
func TestBuildTaskActionName(t *testing.T) {
134161
runID := &common.RunIdentifier{
135162
Org: "org",
@@ -138,21 +165,21 @@ func TestBuildTaskActionName(t *testing.T) {
138165
Name: "rabc123",
139166
}
140167

141-
t.Run("root action uses a0-0 suffix", func(t *testing.T) {
168+
t.Run("root action uses a0 suffix", func(t *testing.T) {
142169
// Root: action name == run name
143170
actionID := &common.ActionIdentifier{
144171
Run: runID,
145172
Name: runID.Name,
146173
}
147-
assert.Equal(t, "rabc123-a0-0", buildTaskActionName(actionID))
174+
assert.Equal(t, "rabc123-a0", buildTaskActionName(actionID))
148175
})
149176

150177
t.Run("child action includes action name", func(t *testing.T) {
151178
actionID := &common.ActionIdentifier{
152179
Run: runID,
153180
Name: "train",
154181
}
155-
assert.Equal(t, "rabc123-train-0", buildTaskActionName(actionID))
182+
assert.Equal(t, "rabc123-train", buildTaskActionName(actionID))
156183
})
157184
}
158185

charts/flyte-binary/templates/crds/taskaction.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,14 @@ spec:
231231
- phase
232232
type: object
233233
type: array
234+
attempts:
235+
description: Attempts is the latest observed action attempt number,
236+
starting from 1.
237+
type: integer
238+
cacheStatus:
239+
description: CacheStatus is the latest observed cache lookup result
240+
for this action.
241+
type: integer
234242
pluginPhase:
235243
description: PluginPhase is a human-readable representation of the
236244
plugin's current phase.

docker/sandbox-bundled/manifests/complete.yaml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,14 @@ spec:
141141
status:
142142
description: status defines the observed state of TaskAction
143143
properties:
144+
attempts:
145+
description: Attempts is the latest observed action attempt number,
146+
starting from 1.
147+
type: integer
148+
cacheStatus:
149+
description: CacheStatus is the latest observed cache lookup result
150+
for this action.
151+
type: integer
144152
conditions:
145153
description: |-
146154
conditions represent the current state of the TaskAction resource.
@@ -1049,7 +1057,7 @@ type: Opaque
10491057
---
10501058
apiVersion: v1
10511059
data:
1052-
haSharedSecret: RGhWVnFVakdvTU93amsxbA==
1060+
haSharedSecret: bHZNNWFFbjVJYzlZSHJaQg==
10531061
proxyPassword: ""
10541062
proxyUsername: ""
10551063
kind: Secret
@@ -1575,7 +1583,7 @@ spec:
15751583
metadata:
15761584
annotations:
15771585
checksum/config: 8f50e768255a87f078ba8b9879a0c174c3e045ffb46ac8723d2eedbe293c8d81
1578-
checksum/secret: d75874ebdfdca751387f7a272a14ca6f600f03189ebda404575c1645aa7e399a
1586+
checksum/secret: db8101eb87292a64110c1efe3ac20b104e11ec74eeecdadd0641e90529028080
15791587
labels:
15801588
app: docker-registry
15811589
release: flyte-sandbox

docker/sandbox-bundled/manifests/dev.yaml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,14 @@ spec:
141141
status:
142142
description: status defines the observed state of TaskAction
143143
properties:
144+
attempts:
145+
description: Attempts is the latest observed action attempt number,
146+
starting from 1.
147+
type: integer
148+
cacheStatus:
149+
description: CacheStatus is the latest observed cache lookup result
150+
for this action.
151+
type: integer
144152
conditions:
145153
description: |-
146154
conditions represent the current state of the TaskAction resource.
@@ -763,7 +771,7 @@ metadata:
763771
---
764772
apiVersion: v1
765773
data:
766-
haSharedSecret: TDY5Y2lPS0xUMFBRNUZLMQ==
774+
haSharedSecret: ZUdTa2I5SE8xNU5rUlZtdQ==
767775
proxyPassword: ""
768776
proxyUsername: ""
769777
kind: Secret
@@ -1192,7 +1200,7 @@ spec:
11921200
metadata:
11931201
annotations:
11941202
checksum/config: 8f50e768255a87f078ba8b9879a0c174c3e045ffb46ac8723d2eedbe293c8d81
1195-
checksum/secret: e5c345361f6169f960277db5c368b1f1c2a7ea280b8e17a066ac1c369806b160
1203+
checksum/secret: f57b9b6e36305b7992e6f671cc05f984c1bc49359d77acdff1b13f761c03699c
11961204
labels:
11971205
app: docker-registry
11981206
release: flyte-sandbox

executor/api/v1/taskaction_types.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2121

2222
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common"
23+
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
2324
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow"
2425
)
2526

@@ -225,6 +226,14 @@ type TaskActionStatus struct {
225226
// +optional
226227
PluginPhaseVersion uint32 `json:"pluginPhaseVersion,omitempty"`
227228

229+
// Attempts is the latest observed action attempt number, starting from 1.
230+
// +optional
231+
Attempts uint32 `json:"attempts,omitempty"`
232+
233+
// CacheStatus is the latest observed cache lookup result for this action.
234+
// +optional
235+
CacheStatus core.CatalogCacheStatus `json:"cacheStatus,omitempty"`
236+
228237
// conditions represent the current state of the TaskAction resource.
229238
// Each condition has a unique type and reflects the status of a specific aspect of the resource.
230239
//

executor/config/crd/bases/flyte.org_taskactions.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,14 @@ spec:
231231
- phase
232232
type: object
233233
type: array
234+
attempts:
235+
description: Attempts is the latest observed action attempt number,
236+
starting from 1.
237+
type: integer
238+
cacheStatus:
239+
description: CacheStatus is the latest observed cache lookup result
240+
for this action.
241+
type: integer
234242
pluginPhase:
235243
description: PluginPhase is a human-readable representation of the
236244
plugin's current phase.

executor/pkg/controller/taskaction_controller.go

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ func (r *TaskActionReconciler) Reconcile(ctx context.Context, req ctrl.Request)
211211

212212
taskAction.Status.PluginPhase = phaseInfo.Phase().String()
213213
taskAction.Status.PluginPhaseVersion = phaseInfo.Version()
214+
taskAction.Status.Attempts = observedAttempts(taskAction)
215+
taskAction.Status.CacheStatus = observedCacheStatus(phaseInfo.Info())
214216

215217
if err := r.updateTaskActionStatus(ctx, originalTaskActionInstance, taskAction, phaseInfo); err != nil {
216218
return ctrl.Result{}, err
@@ -359,7 +361,7 @@ func (r *TaskActionReconciler) buildActionEvent(
359361

360362
event := &workflow.ActionEvent{
361363
Id: actionID,
362-
Attempt: 1, // TODO(nary): wire retry attempt once retry state is available in executor status.
364+
Attempt: observedAttempts(taskAction),
363365
Phase: phaseToActionPhase(phaseInfo.Phase()),
364366
Version: phaseInfo.Version(),
365367
UpdatedTime: updatedTime,
@@ -373,12 +375,27 @@ func (r *TaskActionReconciler) buildActionEvent(
373375
if info != nil {
374376
event.LogInfo = info.Logs
375377
event.LogContext = info.LogContext
376-
event.CacheStatus = cacheStatusFromExternalResources(info.ExternalResources)
377378
}
379+
event.CacheStatus = observedCacheStatus(info)
378380

379381
return event
380382
}
381383

384+
func observedAttempts(taskAction *flyteorgv1.TaskAction) uint32 {
385+
if taskAction.Status.Attempts > 0 {
386+
return taskAction.Status.Attempts
387+
}
388+
// if attempts is not set, default to 1
389+
return 1
390+
}
391+
392+
func observedCacheStatus(info *pluginsCore.TaskInfo) core.CatalogCacheStatus {
393+
if info == nil {
394+
return core.CatalogCacheStatus_CACHE_DISABLED
395+
}
396+
return cacheStatusFromExternalResources(info.ExternalResources)
397+
}
398+
382399
func updatedTimestamp(info *pluginsCore.TaskInfo, history []flyteorgv1.PhaseTransition) *timestamppb.Timestamp {
383400
if info != nil && info.OccurredAt != nil {
384401
return timestamppb.New(*info.OccurredAt)
@@ -481,7 +498,9 @@ func taskActionStatusChanged(oldStatus, newStatus flyteorgv1.TaskActionStatus) b
481498
if oldStatus.StateJSON != newStatus.StateJSON ||
482499
oldStatus.PluginStateVersion != newStatus.PluginStateVersion ||
483500
oldStatus.PluginPhase != newStatus.PluginPhase ||
484-
oldStatus.PluginPhaseVersion != newStatus.PluginPhaseVersion {
501+
oldStatus.PluginPhaseVersion != newStatus.PluginPhaseVersion ||
502+
oldStatus.Attempts != newStatus.Attempts ||
503+
oldStatus.CacheStatus != newStatus.CacheStatus {
485504
return true
486505
}
487506

executor/pkg/plugin/task_exec_metadata.go

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package plugin
22

33
import (
4+
"fmt"
5+
46
"google.golang.org/protobuf/proto"
57
v1 "k8s.io/api/core/v1"
68
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -80,14 +82,17 @@ func NewTaskExecutionMetadata(ta *flyteorgv1.TaskAction) (pluginsCore.TaskExecut
8082
"_U_ORG_NAME": ta.Spec.Org,
8183
"_U_RUN_BASE": ta.Spec.RunOutputBase,
8284
}
85+
generatedName := buildGeneratedName(ta)
86+
retryAttempt := attemptToRetry(ta.Status.Attempts)
87+
maxAttempts := maxAttemptsFromTaskTemplate(ta.Spec.TaskTemplate)
8388

8489
return &taskExecutionMetadata{
8590
ownerID: types.NamespacedName{
8691
Name: ta.Name,
8792
Namespace: ta.Namespace,
8893
},
8994
taskExecutionID: &taskExecutionID{
90-
generatedName: ta.Name,
95+
generatedName: generatedName,
9196
id: core.TaskExecutionIdentifier{
9297
NodeExecutionId: &core.NodeExecutionIdentifier{
9398
ExecutionId: &core.WorkflowExecutionIdentifier{
@@ -98,6 +103,7 @@ func NewTaskExecutionMetadata(ta *flyteorgv1.TaskAction) (pluginsCore.TaskExecut
98103
},
99104
NodeId: ta.Spec.ActionName,
100105
},
106+
RetryAttempt: retryAttempt,
101107
},
102108
},
103109
namespace: ta.Namespace,
@@ -109,13 +115,44 @@ func NewTaskExecutionMetadata(ta *flyteorgv1.TaskAction) (pluginsCore.TaskExecut
109115
},
110116
labels: pluginsUtils.UnionMaps(ta.Labels, injectLabels),
111117
annotations: pluginsUtils.UnionMaps(ta.Annotations, secretsMap),
112-
maxAttempts: 1,
118+
maxAttempts: maxAttempts,
113119
overrides: overrides,
114120
envVars: envVars,
115121
securityContext: securityContext,
116122
}, nil
117123
}
118124

125+
func buildGeneratedName(ta *flyteorgv1.TaskAction) string {
126+
return fmt.Sprintf("%s-%d", ta.Name, attemptToRetry(ta.Status.Attempts))
127+
}
128+
129+
// attemptToRetry convert attempt to retry count
130+
func attemptToRetry(attempt uint32) uint32 {
131+
if attempt <= 1 {
132+
return 0
133+
}
134+
return attempt - 1
135+
}
136+
137+
// maxAttemptsFromTaskTemplate give the max attempts (retries + 1) from the task template.
138+
func maxAttemptsFromTaskTemplate(data []byte) uint32 {
139+
if len(data) == 0 {
140+
return 1
141+
}
142+
143+
tmpl := &core.TaskTemplate{}
144+
if err := proto.Unmarshal(data, tmpl); err != nil {
145+
return 1
146+
}
147+
148+
md := tmpl.GetMetadata()
149+
if md == nil || md.GetRetries() == nil {
150+
return 1
151+
}
152+
153+
return md.GetRetries().GetRetries() + 1
154+
}
155+
119156
// buildOverridesFromTaskTemplate deserializes the task template and extracts resource requirements.
120157
func buildOverridesFromTaskTemplate(data []byte) *taskOverrides {
121158
if len(data) == 0 {

0 commit comments

Comments
 (0)