Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions actions/k8s/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,7 @@ func (c *ActionsClient) GetState(ctx context.Context, actionID *common.ActionIde
return taskAction.Status.StateJSON, nil
}

// PutState updates the state JSON for a TaskAction.
// attempt and status are accepted for future use (e.g. recording to RunService).
// PutState updates the state JSON and latest attempt metadata for a TaskAction.
func (c *ActionsClient) PutState(ctx context.Context, actionID *common.ActionIdentifier, attempt uint32, status *workflow.ActionStatus, stateJSON string) error {
taskActionName := buildTaskActionName(actionID)

Expand All @@ -210,6 +209,10 @@ func (c *ActionsClient) PutState(ctx context.Context, actionID *common.ActionIde

// Update state JSON
taskAction.Status.StateJSON = stateJSON
if status != nil {
taskAction.Status.Attempts = status.GetAttempts()
taskAction.Status.CacheStatus = status.GetCacheStatus()
}

// Update status subresource
if err := c.k8sClient.Status().Update(ctx, taskAction); err != nil {
Expand Down Expand Up @@ -499,7 +502,9 @@ func (c *ActionsClient) notifyRunService(ctx context.Context, taskAction *execut
statusReq := &workflow.UpdateActionStatusRequest{
ActionId: update.ActionID,
Status: &workflow.ActionStatus{
Phase: update.Phase,
Phase: update.Phase,
Attempts: taskAction.Status.Attempts,
CacheStatus: taskAction.Status.CacheStatus,
},
}
if _, err := c.runClient.UpdateActionStatus(ctx, connect.NewRequest(statusReq)); err != nil {
Expand Down Expand Up @@ -548,15 +553,14 @@ func GetPhaseFromConditions(taskAction *executorv1.TaskAction) common.ActionPhas
}

// buildTaskActionName generates a Kubernetes-compliant name for the TaskAction.
// For root actions (where action name == run name), the name is <run-id>-a0-0.
// For child actions, the name is <run-id>-<action-id>-0.
// The trailing "0" is the attempt number (0-indexed; hardcoded until retry support is added).
// For root actions (where action name == run name), the name is <run-id>-a0.
// For child actions, the name is <run-id>-<action-id>.
func buildTaskActionName(actionID *common.ActionIdentifier) string {
isRoot := actionID.Name == actionID.Run.Name
if isRoot {
return fmt.Sprintf("%s-a0-0", actionID.Run.Name)
return fmt.Sprintf("%s-a0", actionID.Run.Name)
}
return fmt.Sprintf("%s-%s-0", actionID.Run.Name, actionID.Name)
return fmt.Sprintf("%s-%s", actionID.Run.Name, actionID.Name)
}

// buildNamespace returns the Kubernetes namespace for a run: "<project>-<domain>".
Expand Down
33 changes: 30 additions & 3 deletions actions/k8s/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/flyteorg/flyte/v2/flytestdlib/fastcheck"
"github.com/flyteorg/flyte/v2/flytestdlib/promutils"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow"
runmocks "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect/mocks"
)
Expand Down Expand Up @@ -130,6 +131,32 @@ func TestNotifyRunService_NilFilter(t *testing.T) {
mockClient.AssertNumberOfCalls(t, "RecordAction", 2)
}

func TestNotifyRunService_UpdateActionStatusIncludesAttemptsAndCacheStatus(t *testing.T) {
ctx := context.Background()

mockClient := runmocks.NewInternalRunServiceClient(t)
c := &ActionsClient{
runClient: mockClient,
subscribers: make(map[string]map[chan *ActionUpdate]struct{}),
}

ta, update := newTestActionUpdate("action-4")
ta.Status.Attempts = 3
ta.Status.CacheStatus = core.CatalogCacheStatus_CACHE_HIT
update.Phase = common.ActionPhase_ACTION_PHASE_SUCCEEDED

mockClient.On("UpdateActionStatus", mock.Anything, mock.MatchedBy(func(req *connect.Request[workflow.UpdateActionStatusRequest]) bool {
status := req.Msg.GetStatus()
return status.GetPhase() == common.ActionPhase_ACTION_PHASE_SUCCEEDED &&
status.GetAttempts() == 3 &&
status.GetCacheStatus() == core.CatalogCacheStatus_CACHE_HIT
})).Return(&connect.Response[workflow.UpdateActionStatusResponse]{}, nil).Once()

c.notifyRunService(ctx, ta, update, watch.Modified)

mockClient.AssertNumberOfCalls(t, "UpdateActionStatus", 1)
}

func TestBuildTaskActionName(t *testing.T) {
runID := &common.RunIdentifier{
Org: "org",
Expand All @@ -138,21 +165,21 @@ func TestBuildTaskActionName(t *testing.T) {
Name: "rabc123",
}

t.Run("root action uses a0-0 suffix", func(t *testing.T) {
t.Run("root action uses a0 suffix", func(t *testing.T) {
// Root: action name == run name
actionID := &common.ActionIdentifier{
Run: runID,
Name: runID.Name,
}
assert.Equal(t, "rabc123-a0-0", buildTaskActionName(actionID))
assert.Equal(t, "rabc123-a0", buildTaskActionName(actionID))
})

t.Run("child action includes action name", func(t *testing.T) {
actionID := &common.ActionIdentifier{
Run: runID,
Name: "train",
}
assert.Equal(t, "rabc123-train-0", buildTaskActionName(actionID))
assert.Equal(t, "rabc123-train", buildTaskActionName(actionID))
})
}

Expand Down
8 changes: 8 additions & 0 deletions charts/flyte-binary/templates/crds/taskaction.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,14 @@ spec:
- phase
type: object
type: array
attempts:
description: Attempts is the latest observed action attempt number,
starting from 1.
type: integer
cacheStatus:
description: CacheStatus is the latest observed cache lookup result
for this action.
type: integer
pluginPhase:
description: PluginPhase is a human-readable representation of the
plugin's current phase.
Expand Down
12 changes: 10 additions & 2 deletions docker/sandbox-bundled/manifests/complete.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ spec:
status:
description: status defines the observed state of TaskAction
properties:
attempts:
description: Attempts is the latest observed action attempt number,
starting from 1.
type: integer
cacheStatus:
description: CacheStatus is the latest observed cache lookup result
for this action.
type: integer
conditions:
description: |-
conditions represent the current state of the TaskAction resource.
Expand Down Expand Up @@ -1049,7 +1057,7 @@ type: Opaque
---
apiVersion: v1
data:
haSharedSecret: RGhWVnFVakdvTU93amsxbA==
haSharedSecret: bHZNNWFFbjVJYzlZSHJaQg==
proxyPassword: ""
proxyUsername: ""
kind: Secret
Expand Down Expand Up @@ -1575,7 +1583,7 @@ spec:
metadata:
annotations:
checksum/config: 8f50e768255a87f078ba8b9879a0c174c3e045ffb46ac8723d2eedbe293c8d81
checksum/secret: d75874ebdfdca751387f7a272a14ca6f600f03189ebda404575c1645aa7e399a
checksum/secret: db8101eb87292a64110c1efe3ac20b104e11ec74eeecdadd0641e90529028080
labels:
app: docker-registry
release: flyte-sandbox
Expand Down
12 changes: 10 additions & 2 deletions docker/sandbox-bundled/manifests/dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ spec:
status:
description: status defines the observed state of TaskAction
properties:
attempts:
description: Attempts is the latest observed action attempt number,
starting from 1.
type: integer
cacheStatus:
description: CacheStatus is the latest observed cache lookup result
for this action.
type: integer
conditions:
description: |-
conditions represent the current state of the TaskAction resource.
Expand Down Expand Up @@ -763,7 +771,7 @@ metadata:
---
apiVersion: v1
data:
haSharedSecret: TDY5Y2lPS0xUMFBRNUZLMQ==
haSharedSecret: ZUdTa2I5SE8xNU5rUlZtdQ==
proxyPassword: ""
proxyUsername: ""
kind: Secret
Expand Down Expand Up @@ -1192,7 +1200,7 @@ spec:
metadata:
annotations:
checksum/config: 8f50e768255a87f078ba8b9879a0c174c3e045ffb46ac8723d2eedbe293c8d81
checksum/secret: e5c345361f6169f960277db5c368b1f1c2a7ea280b8e17a066ac1c369806b160
checksum/secret: f57b9b6e36305b7992e6f671cc05f984c1bc49359d77acdff1b13f761c03699c
labels:
app: docker-registry
release: flyte-sandbox
Expand Down
9 changes: 9 additions & 0 deletions executor/api/v1/taskaction_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow"
)

Expand Down Expand Up @@ -225,6 +226,14 @@ type TaskActionStatus struct {
// +optional
PluginPhaseVersion uint32 `json:"pluginPhaseVersion,omitempty"`

// Attempts is the latest observed action attempt number, starting from 1.
// +optional
Attempts uint32 `json:"attempts,omitempty"`

// CacheStatus is the latest observed cache lookup result for this action.
// +optional
CacheStatus core.CatalogCacheStatus `json:"cacheStatus,omitempty"`

// conditions represent the current state of the TaskAction resource.
// Each condition has a unique type and reflects the status of a specific aspect of the resource.
//
Expand Down
8 changes: 8 additions & 0 deletions executor/config/crd/bases/flyte.org_taskactions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,14 @@ spec:
- phase
type: object
type: array
attempts:
description: Attempts is the latest observed action attempt number,
starting from 1.
type: integer
cacheStatus:
description: CacheStatus is the latest observed cache lookup result
for this action.
type: integer
pluginPhase:
description: PluginPhase is a human-readable representation of the
plugin's current phase.
Expand Down
25 changes: 22 additions & 3 deletions executor/pkg/controller/taskaction_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ func (r *TaskActionReconciler) Reconcile(ctx context.Context, req ctrl.Request)

taskAction.Status.PluginPhase = phaseInfo.Phase().String()
taskAction.Status.PluginPhaseVersion = phaseInfo.Version()
taskAction.Status.Attempts = observedAttempts(taskAction)
taskAction.Status.CacheStatus = observedCacheStatus(phaseInfo.Info())

if err := r.updateTaskActionStatus(ctx, originalTaskActionInstance, taskAction, phaseInfo); err != nil {
return ctrl.Result{}, err
Expand Down Expand Up @@ -359,7 +361,7 @@ func (r *TaskActionReconciler) buildActionEvent(

event := &workflow.ActionEvent{
Id: actionID,
Attempt: 1, // TODO(nary): wire retry attempt once retry state is available in executor status.
Attempt: observedAttempts(taskAction),
Phase: phaseToActionPhase(phaseInfo.Phase()),
Version: phaseInfo.Version(),
UpdatedTime: updatedTime,
Expand All @@ -373,12 +375,27 @@ func (r *TaskActionReconciler) buildActionEvent(
if info != nil {
event.LogInfo = info.Logs
event.LogContext = info.LogContext
event.CacheStatus = cacheStatusFromExternalResources(info.ExternalResources)
}
event.CacheStatus = observedCacheStatus(info)

return event
}

func observedAttempts(taskAction *flyteorgv1.TaskAction) uint32 {
if taskAction.Status.Attempts > 0 {
return taskAction.Status.Attempts
}
// if attempts is not set, default to 1
return 1
}

func observedCacheStatus(info *pluginsCore.TaskInfo) core.CatalogCacheStatus {
if info == nil {
return core.CatalogCacheStatus_CACHE_DISABLED
}
return cacheStatusFromExternalResources(info.ExternalResources)
}

func updatedTimestamp(info *pluginsCore.TaskInfo, history []flyteorgv1.PhaseTransition) *timestamppb.Timestamp {
if info != nil && info.OccurredAt != nil {
return timestamppb.New(*info.OccurredAt)
Expand Down Expand Up @@ -481,7 +498,9 @@ func taskActionStatusChanged(oldStatus, newStatus flyteorgv1.TaskActionStatus) b
if oldStatus.StateJSON != newStatus.StateJSON ||
oldStatus.PluginStateVersion != newStatus.PluginStateVersion ||
oldStatus.PluginPhase != newStatus.PluginPhase ||
oldStatus.PluginPhaseVersion != newStatus.PluginPhaseVersion {
oldStatus.PluginPhaseVersion != newStatus.PluginPhaseVersion ||
oldStatus.Attempts != newStatus.Attempts ||
oldStatus.CacheStatus != newStatus.CacheStatus {
return true
}

Expand Down
41 changes: 39 additions & 2 deletions executor/pkg/plugin/task_exec_metadata.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package plugin

import (
"fmt"

"google.golang.org/protobuf/proto"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -80,14 +82,17 @@ func NewTaskExecutionMetadata(ta *flyteorgv1.TaskAction) (pluginsCore.TaskExecut
"_U_ORG_NAME": ta.Spec.Org,
"_U_RUN_BASE": ta.Spec.RunOutputBase,
}
generatedName := buildGeneratedName(ta)
retryAttempt := attemptToRetry(ta.Status.Attempts)
maxAttempts := maxAttemptsFromTaskTemplate(ta.Spec.TaskTemplate)

return &taskExecutionMetadata{
ownerID: types.NamespacedName{
Name: ta.Name,
Namespace: ta.Namespace,
},
taskExecutionID: &taskExecutionID{
generatedName: ta.Name,
generatedName: generatedName,
id: core.TaskExecutionIdentifier{
NodeExecutionId: &core.NodeExecutionIdentifier{
ExecutionId: &core.WorkflowExecutionIdentifier{
Expand All @@ -98,6 +103,7 @@ func NewTaskExecutionMetadata(ta *flyteorgv1.TaskAction) (pluginsCore.TaskExecut
},
NodeId: ta.Spec.ActionName,
},
RetryAttempt: retryAttempt,
},
},
namespace: ta.Namespace,
Expand All @@ -109,13 +115,44 @@ func NewTaskExecutionMetadata(ta *flyteorgv1.TaskAction) (pluginsCore.TaskExecut
},
labels: pluginsUtils.UnionMaps(ta.Labels, injectLabels),
annotations: pluginsUtils.UnionMaps(ta.Annotations, secretsMap),
maxAttempts: 1,
maxAttempts: maxAttempts,
overrides: overrides,
envVars: envVars,
securityContext: securityContext,
}, nil
}

func buildGeneratedName(ta *flyteorgv1.TaskAction) string {
return fmt.Sprintf("%s-%d", ta.Name, attemptToRetry(ta.Status.Attempts))
}

// attemptToRetry convert attempt to retry count
func attemptToRetry(attempt uint32) uint32 {
if attempt <= 1 {
return 0
}
return attempt - 1
}

// maxAttemptsFromTaskTemplate give the max attempts (retries + 1) from the task template.
func maxAttemptsFromTaskTemplate(data []byte) uint32 {
if len(data) == 0 {
return 1
}

tmpl := &core.TaskTemplate{}
if err := proto.Unmarshal(data, tmpl); err != nil {
return 1
}

md := tmpl.GetMetadata()
if md == nil || md.GetRetries() == nil {
return 1
}

return md.GetRetries().GetRetries() + 1
}

// buildOverridesFromTaskTemplate deserializes the task template and extracts resource requirements.
func buildOverridesFromTaskTemplate(data []byte) *taskOverrides {
if len(data) == 0 {
Expand Down
Loading
Loading