diff --git a/actions/k8s/client.go b/actions/k8s/client.go index fa84fe8e7c..c60783067e 100644 --- a/actions/k8s/client.go +++ b/actions/k8s/client.go @@ -189,8 +189,7 @@ func (c *ActionsClient) GetState(ctx context.Context, actionID *common.ActionIde return taskAction.Status.StateJSON, nil } -// PutState updates the state JSON for a TaskAction. -// attempt and status are accepted for future use (e.g. recording to RunService). +// PutState updates the state JSON and latest attempt metadata for a TaskAction. func (c *ActionsClient) PutState(ctx context.Context, actionID *common.ActionIdentifier, attempt uint32, status *workflow.ActionStatus, stateJSON string) error { taskActionName := buildTaskActionName(actionID) @@ -210,6 +209,10 @@ func (c *ActionsClient) PutState(ctx context.Context, actionID *common.ActionIde // Update state JSON taskAction.Status.StateJSON = stateJSON + if status != nil { + taskAction.Status.Attempts = status.GetAttempts() + taskAction.Status.CacheStatus = status.GetCacheStatus() + } // Update status subresource if err := c.k8sClient.Status().Update(ctx, taskAction); err != nil { @@ -499,7 +502,9 @@ func (c *ActionsClient) notifyRunService(ctx context.Context, taskAction *execut statusReq := &workflow.UpdateActionStatusRequest{ ActionId: update.ActionID, Status: &workflow.ActionStatus{ - Phase: update.Phase, + Phase: update.Phase, + Attempts: taskAction.Status.Attempts, + CacheStatus: taskAction.Status.CacheStatus, }, } if _, err := c.runClient.UpdateActionStatus(ctx, connect.NewRequest(statusReq)); err != nil { @@ -548,15 +553,14 @@ func GetPhaseFromConditions(taskAction *executorv1.TaskAction) common.ActionPhas } // buildTaskActionName generates a Kubernetes-compliant name for the TaskAction. -// For root actions (where action name == run name), the name is -a0-0. -// For child actions, the name is --0. -// The trailing "0" is the attempt number (0-indexed; hardcoded until retry support is added). +// For root actions (where action name == run name), the name is -a0. +// For child actions, the name is -. func buildTaskActionName(actionID *common.ActionIdentifier) string { isRoot := actionID.Name == actionID.Run.Name if isRoot { - return fmt.Sprintf("%s-a0-0", actionID.Run.Name) + return fmt.Sprintf("%s-a0", actionID.Run.Name) } - return fmt.Sprintf("%s-%s-0", actionID.Run.Name, actionID.Name) + return fmt.Sprintf("%s-%s", actionID.Run.Name, actionID.Name) } // buildNamespace returns the Kubernetes namespace for a run: "-". diff --git a/actions/k8s/client_test.go b/actions/k8s/client_test.go index d7bd483fea..b921e0e04a 100644 --- a/actions/k8s/client_test.go +++ b/actions/k8s/client_test.go @@ -14,6 +14,7 @@ import ( "github.com/flyteorg/flyte/v2/flytestdlib/fastcheck" "github.com/flyteorg/flyte/v2/flytestdlib/promutils" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow" runmocks "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect/mocks" ) @@ -130,6 +131,32 @@ func TestNotifyRunService_NilFilter(t *testing.T) { mockClient.AssertNumberOfCalls(t, "RecordAction", 2) } +func TestNotifyRunService_UpdateActionStatusIncludesAttemptsAndCacheStatus(t *testing.T) { + ctx := context.Background() + + mockClient := runmocks.NewInternalRunServiceClient(t) + c := &ActionsClient{ + runClient: mockClient, + subscribers: make(map[string]map[chan *ActionUpdate]struct{}), + } + + ta, update := newTestActionUpdate("action-4") + ta.Status.Attempts = 3 + ta.Status.CacheStatus = core.CatalogCacheStatus_CACHE_HIT + update.Phase = common.ActionPhase_ACTION_PHASE_SUCCEEDED + + mockClient.On("UpdateActionStatus", mock.Anything, mock.MatchedBy(func(req *connect.Request[workflow.UpdateActionStatusRequest]) bool { + status := req.Msg.GetStatus() + return status.GetPhase() == common.ActionPhase_ACTION_PHASE_SUCCEEDED && + status.GetAttempts() == 3 && + status.GetCacheStatus() == core.CatalogCacheStatus_CACHE_HIT + })).Return(&connect.Response[workflow.UpdateActionStatusResponse]{}, nil).Once() + + c.notifyRunService(ctx, ta, update, watch.Modified) + + mockClient.AssertNumberOfCalls(t, "UpdateActionStatus", 1) +} + func TestBuildTaskActionName(t *testing.T) { runID := &common.RunIdentifier{ Org: "org", @@ -138,13 +165,13 @@ func TestBuildTaskActionName(t *testing.T) { Name: "rabc123", } - t.Run("root action uses a0-0 suffix", func(t *testing.T) { + t.Run("root action uses a0 suffix", func(t *testing.T) { // Root: action name == run name actionID := &common.ActionIdentifier{ Run: runID, Name: runID.Name, } - assert.Equal(t, "rabc123-a0-0", buildTaskActionName(actionID)) + assert.Equal(t, "rabc123-a0", buildTaskActionName(actionID)) }) t.Run("child action includes action name", func(t *testing.T) { @@ -152,7 +179,7 @@ func TestBuildTaskActionName(t *testing.T) { Run: runID, Name: "train", } - assert.Equal(t, "rabc123-train-0", buildTaskActionName(actionID)) + assert.Equal(t, "rabc123-train", buildTaskActionName(actionID)) }) } diff --git a/charts/flyte-binary/templates/crds/taskaction.yaml b/charts/flyte-binary/templates/crds/taskaction.yaml index c0be44bf43..aeb05c0e2f 100644 --- a/charts/flyte-binary/templates/crds/taskaction.yaml +++ b/charts/flyte-binary/templates/crds/taskaction.yaml @@ -231,6 +231,14 @@ spec: - phase type: object type: array + attempts: + description: Attempts is the latest observed action attempt number, + starting from 1. + type: integer + cacheStatus: + description: CacheStatus is the latest observed cache lookup result + for this action. + type: integer pluginPhase: description: PluginPhase is a human-readable representation of the plugin's current phase. diff --git a/docker/sandbox-bundled/manifests/complete.yaml b/docker/sandbox-bundled/manifests/complete.yaml index a52acc7e7f..3dee1dc880 100644 --- a/docker/sandbox-bundled/manifests/complete.yaml +++ b/docker/sandbox-bundled/manifests/complete.yaml @@ -141,6 +141,14 @@ spec: status: description: status defines the observed state of TaskAction properties: + attempts: + description: Attempts is the latest observed action attempt number, + starting from 1. + type: integer + cacheStatus: + description: CacheStatus is the latest observed cache lookup result + for this action. + type: integer conditions: description: |- conditions represent the current state of the TaskAction resource. @@ -1049,7 +1057,7 @@ type: Opaque --- apiVersion: v1 data: - haSharedSecret: RGhWVnFVakdvTU93amsxbA== + haSharedSecret: bHZNNWFFbjVJYzlZSHJaQg== proxyPassword: "" proxyUsername: "" kind: Secret @@ -1575,7 +1583,7 @@ spec: metadata: annotations: checksum/config: 8f50e768255a87f078ba8b9879a0c174c3e045ffb46ac8723d2eedbe293c8d81 - checksum/secret: d75874ebdfdca751387f7a272a14ca6f600f03189ebda404575c1645aa7e399a + checksum/secret: db8101eb87292a64110c1efe3ac20b104e11ec74eeecdadd0641e90529028080 labels: app: docker-registry release: flyte-sandbox diff --git a/docker/sandbox-bundled/manifests/dev.yaml b/docker/sandbox-bundled/manifests/dev.yaml index b89d2f0278..d6caad8060 100644 --- a/docker/sandbox-bundled/manifests/dev.yaml +++ b/docker/sandbox-bundled/manifests/dev.yaml @@ -141,6 +141,14 @@ spec: status: description: status defines the observed state of TaskAction properties: + attempts: + description: Attempts is the latest observed action attempt number, + starting from 1. + type: integer + cacheStatus: + description: CacheStatus is the latest observed cache lookup result + for this action. + type: integer conditions: description: |- conditions represent the current state of the TaskAction resource. @@ -763,7 +771,7 @@ metadata: --- apiVersion: v1 data: - haSharedSecret: TDY5Y2lPS0xUMFBRNUZLMQ== + haSharedSecret: ZUdTa2I5SE8xNU5rUlZtdQ== proxyPassword: "" proxyUsername: "" kind: Secret @@ -1192,7 +1200,7 @@ spec: metadata: annotations: checksum/config: 8f50e768255a87f078ba8b9879a0c174c3e045ffb46ac8723d2eedbe293c8d81 - checksum/secret: e5c345361f6169f960277db5c368b1f1c2a7ea280b8e17a066ac1c369806b160 + checksum/secret: f57b9b6e36305b7992e6f671cc05f984c1bc49359d77acdff1b13f761c03699c labels: app: docker-registry release: flyte-sandbox diff --git a/executor/api/v1/taskaction_types.go b/executor/api/v1/taskaction_types.go index ee59e14e72..930f778695 100644 --- a/executor/api/v1/taskaction_types.go +++ b/executor/api/v1/taskaction_types.go @@ -20,6 +20,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow" ) @@ -225,6 +226,14 @@ type TaskActionStatus struct { // +optional PluginPhaseVersion uint32 `json:"pluginPhaseVersion,omitempty"` + // Attempts is the latest observed action attempt number, starting from 1. + // +optional + Attempts uint32 `json:"attempts,omitempty"` + + // CacheStatus is the latest observed cache lookup result for this action. + // +optional + CacheStatus core.CatalogCacheStatus `json:"cacheStatus,omitempty"` + // conditions represent the current state of the TaskAction resource. // Each condition has a unique type and reflects the status of a specific aspect of the resource. // diff --git a/executor/config/crd/bases/flyte.org_taskactions.yaml b/executor/config/crd/bases/flyte.org_taskactions.yaml index 2301767872..b8633e7623 100644 --- a/executor/config/crd/bases/flyte.org_taskactions.yaml +++ b/executor/config/crd/bases/flyte.org_taskactions.yaml @@ -231,6 +231,14 @@ spec: - phase type: object type: array + attempts: + description: Attempts is the latest observed action attempt number, + starting from 1. + type: integer + cacheStatus: + description: CacheStatus is the latest observed cache lookup result + for this action. + type: integer pluginPhase: description: PluginPhase is a human-readable representation of the plugin's current phase. diff --git a/executor/pkg/controller/taskaction_controller.go b/executor/pkg/controller/taskaction_controller.go index 44a5002a23..843b5e56b4 100644 --- a/executor/pkg/controller/taskaction_controller.go +++ b/executor/pkg/controller/taskaction_controller.go @@ -211,6 +211,8 @@ func (r *TaskActionReconciler) Reconcile(ctx context.Context, req ctrl.Request) taskAction.Status.PluginPhase = phaseInfo.Phase().String() taskAction.Status.PluginPhaseVersion = phaseInfo.Version() + taskAction.Status.Attempts = observedAttempts(taskAction) + taskAction.Status.CacheStatus = observedCacheStatus(phaseInfo.Info()) if err := r.updateTaskActionStatus(ctx, originalTaskActionInstance, taskAction, phaseInfo); err != nil { return ctrl.Result{}, err @@ -359,7 +361,7 @@ func (r *TaskActionReconciler) buildActionEvent( event := &workflow.ActionEvent{ Id: actionID, - Attempt: 1, // TODO(nary): wire retry attempt once retry state is available in executor status. + Attempt: observedAttempts(taskAction), Phase: phaseToActionPhase(phaseInfo.Phase()), Version: phaseInfo.Version(), UpdatedTime: updatedTime, @@ -373,12 +375,27 @@ func (r *TaskActionReconciler) buildActionEvent( if info != nil { event.LogInfo = info.Logs event.LogContext = info.LogContext - event.CacheStatus = cacheStatusFromExternalResources(info.ExternalResources) } + event.CacheStatus = observedCacheStatus(info) return event } +func observedAttempts(taskAction *flyteorgv1.TaskAction) uint32 { + if taskAction.Status.Attempts > 0 { + return taskAction.Status.Attempts + } + // if attempts is not set, default to 1 + return 1 +} + +func observedCacheStatus(info *pluginsCore.TaskInfo) core.CatalogCacheStatus { + if info == nil { + return core.CatalogCacheStatus_CACHE_DISABLED + } + return cacheStatusFromExternalResources(info.ExternalResources) +} + func updatedTimestamp(info *pluginsCore.TaskInfo, history []flyteorgv1.PhaseTransition) *timestamppb.Timestamp { if info != nil && info.OccurredAt != nil { return timestamppb.New(*info.OccurredAt) @@ -481,7 +498,9 @@ func taskActionStatusChanged(oldStatus, newStatus flyteorgv1.TaskActionStatus) b if oldStatus.StateJSON != newStatus.StateJSON || oldStatus.PluginStateVersion != newStatus.PluginStateVersion || oldStatus.PluginPhase != newStatus.PluginPhase || - oldStatus.PluginPhaseVersion != newStatus.PluginPhaseVersion { + oldStatus.PluginPhaseVersion != newStatus.PluginPhaseVersion || + oldStatus.Attempts != newStatus.Attempts || + oldStatus.CacheStatus != newStatus.CacheStatus { return true } diff --git a/executor/pkg/plugin/task_exec_metadata.go b/executor/pkg/plugin/task_exec_metadata.go index 1c94ea4b5b..5ec36c3515 100644 --- a/executor/pkg/plugin/task_exec_metadata.go +++ b/executor/pkg/plugin/task_exec_metadata.go @@ -1,6 +1,8 @@ package plugin import ( + "fmt" + "google.golang.org/protobuf/proto" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -80,6 +82,9 @@ func NewTaskExecutionMetadata(ta *flyteorgv1.TaskAction) (pluginsCore.TaskExecut "_U_ORG_NAME": ta.Spec.Org, "_U_RUN_BASE": ta.Spec.RunOutputBase, } + generatedName := buildGeneratedName(ta) + retryAttempt := attemptToRetry(ta.Status.Attempts) + maxAttempts := maxAttemptsFromTaskTemplate(ta.Spec.TaskTemplate) return &taskExecutionMetadata{ ownerID: types.NamespacedName{ @@ -87,7 +92,7 @@ func NewTaskExecutionMetadata(ta *flyteorgv1.TaskAction) (pluginsCore.TaskExecut Namespace: ta.Namespace, }, taskExecutionID: &taskExecutionID{ - generatedName: ta.Name, + generatedName: generatedName, id: core.TaskExecutionIdentifier{ NodeExecutionId: &core.NodeExecutionIdentifier{ ExecutionId: &core.WorkflowExecutionIdentifier{ @@ -98,6 +103,7 @@ func NewTaskExecutionMetadata(ta *flyteorgv1.TaskAction) (pluginsCore.TaskExecut }, NodeId: ta.Spec.ActionName, }, + RetryAttempt: retryAttempt, }, }, namespace: ta.Namespace, @@ -109,13 +115,44 @@ func NewTaskExecutionMetadata(ta *flyteorgv1.TaskAction) (pluginsCore.TaskExecut }, labels: pluginsUtils.UnionMaps(ta.Labels, injectLabels), annotations: pluginsUtils.UnionMaps(ta.Annotations, secretsMap), - maxAttempts: 1, + maxAttempts: maxAttempts, overrides: overrides, envVars: envVars, securityContext: securityContext, }, nil } +func buildGeneratedName(ta *flyteorgv1.TaskAction) string { + return fmt.Sprintf("%s-%d", ta.Name, attemptToRetry(ta.Status.Attempts)) +} + +// attemptToRetry convert attempt to retry count +func attemptToRetry(attempt uint32) uint32 { + if attempt <= 1 { + return 0 + } + return attempt - 1 +} + +// maxAttemptsFromTaskTemplate give the max attempts (retries + 1) from the task template. +func maxAttemptsFromTaskTemplate(data []byte) uint32 { + if len(data) == 0 { + return 1 + } + + tmpl := &core.TaskTemplate{} + if err := proto.Unmarshal(data, tmpl); err != nil { + return 1 + } + + md := tmpl.GetMetadata() + if md == nil || md.GetRetries() == nil { + return 1 + } + + return md.GetRetries().GetRetries() + 1 +} + // buildOverridesFromTaskTemplate deserializes the task template and extracts resource requirements. func buildOverridesFromTaskTemplate(data []byte) *taskOverrides { if len(data) == 0 { diff --git a/runs/repository/impl/action.go b/runs/repository/impl/action.go index 55af25c797..e657d2f821 100644 --- a/runs/repository/impl/action.go +++ b/runs/repository/impl/action.go @@ -19,6 +19,7 @@ import ( "github.com/flyteorg/flyte/v2/flytestdlib/database" "github.com/flyteorg/flyte/v2/flytestdlib/logger" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow" "github.com/flyteorg/flyte/v2/runs/repository/interfaces" "github.com/flyteorg/flyte/v2/runs/repository/models" @@ -422,10 +423,19 @@ func (r *actionRepo) ListActions(ctx context.Context, runID *common.RunIdentifie // UpdateActionPhase updates the phase of an action. // endTime should be set when the action reaches a terminal phase. -func (r *actionRepo) UpdateActionPhase(ctx context.Context, actionID *common.ActionIdentifier, phase common.ActionPhase, endTime *time.Time) error { +func (r *actionRepo) UpdateActionPhase( + ctx context.Context, + actionID *common.ActionIdentifier, + phase common.ActionPhase, + attempts uint32, + cacheStatus core.CatalogCacheStatus, + endTime *time.Time, +) error { updates := map[string]interface{}{ - "phase": phase, - "updated_at": time.Now(), + "phase": phase, + "attempts": attempts, + "cache_status": cacheStatus, + "updated_at": time.Now(), } if endTime != nil { if r.isPostgres { diff --git a/runs/repository/impl/action_test.go b/runs/repository/impl/action_test.go index ebd399f096..d244b110ee 100644 --- a/runs/repository/impl/action_test.go +++ b/runs/repository/impl/action_test.go @@ -3,9 +3,11 @@ package impl import ( "context" "testing" + "time" "github.com/flyteorg/flyte/v2/flytestdlib/database" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow" "github.com/flyteorg/flyte/v2/runs/repository/models" @@ -109,6 +111,47 @@ func TestCreateRun(t *testing.T) { require.Error(t, err) } +func TestUpdateActionPhasePersistsAttemptsAndCacheStatus(t *testing.T) { + db := setupActionDB(t) + defer func() { _ = db.Exec("DELETE FROM actions") }() + actionRepo := NewActionRepo(db, database.DbConfig{}) + ctx := context.Background() + + actionID := &common.ActionIdentifier{ + Run: &common.RunIdentifier{ + Org: "org1", + Project: "proj1", + Domain: "domain1", + Name: "run1", + }, + Name: "action1", + } + + _, err := actionRepo.CreateAction(ctx, &workflow.ActionSpec{ + ActionId: actionID, + InputUri: "s3://bucket/input", + }, nil) + require.NoError(t, err) + + endTime := time.Now() + err = actionRepo.UpdateActionPhase( + ctx, + actionID, + common.ActionPhase_ACTION_PHASE_SUCCEEDED, + 3, + core.CatalogCacheStatus_CACHE_HIT, + &endTime, + ) + require.NoError(t, err) + + action, err := actionRepo.GetAction(ctx, actionID) + require.NoError(t, err) + assert.Equal(t, int32(common.ActionPhase_ACTION_PHASE_SUCCEEDED), action.Phase) + assert.Equal(t, uint32(3), action.Attempts) + assert.Equal(t, core.CatalogCacheStatus_CACHE_HIT, action.CacheStatus) + assert.True(t, action.EndedAt.Valid) +} + func TestListRuns(t *testing.T) { db := setupActionDB(t) defer func() { _ = db.Exec("DELETE FROM actions") }() diff --git a/runs/repository/interfaces/action.go b/runs/repository/interfaces/action.go index 5ce69fa679..2d78a83133 100644 --- a/runs/repository/interfaces/action.go +++ b/runs/repository/interfaces/action.go @@ -5,6 +5,7 @@ import ( "time" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow" "github.com/flyteorg/flyte/v2/runs/repository/models" ) @@ -23,7 +24,7 @@ type ActionRepo interface { ListEvents(ctx context.Context, actionID *common.ActionIdentifier, limit int) ([]*models.ActionEvent, error) GetAction(ctx context.Context, actionID *common.ActionIdentifier) (*models.Action, error) ListActions(ctx context.Context, runID *common.RunIdentifier, limit int, token string) ([]*models.Action, string, error) - UpdateActionPhase(ctx context.Context, actionID *common.ActionIdentifier, phase common.ActionPhase, endTime *time.Time) error + UpdateActionPhase(ctx context.Context, actionID *common.ActionIdentifier, phase common.ActionPhase, attempts uint32, cacheStatus core.CatalogCacheStatus, endTime *time.Time) error AbortAction(ctx context.Context, actionID *common.ActionIdentifier, reason string, abortedBy *common.EnrichedIdentity) error // Watch operations (for streaming) diff --git a/runs/repository/mocks/action_repo.go b/runs/repository/mocks/action_repo.go index a4a71b6f5a..f07ff4d903 100644 --- a/runs/repository/mocks/action_repo.go +++ b/runs/repository/mocks/action_repo.go @@ -7,6 +7,8 @@ import ( common "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" + core "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" + mock "github.com/stretchr/testify/mock" models "github.com/flyteorg/flyte/v2/runs/repository/models" @@ -775,17 +777,17 @@ func (_c *ActionRepo_NotifyStateUpdate_Call) RunAndReturn(run func(context.Conte return _c } -// UpdateActionPhase provides a mock function with given fields: ctx, actionID, phase, endTime -func (_m *ActionRepo) UpdateActionPhase(ctx context.Context, actionID *common.ActionIdentifier, phase common.ActionPhase, endTime *time.Time) error { - ret := _m.Called(ctx, actionID, phase, endTime) +// UpdateActionPhase provides a mock function with given fields: ctx, actionID, phase, attempts, cacheStatus, endTime +func (_m *ActionRepo) UpdateActionPhase(ctx context.Context, actionID *common.ActionIdentifier, phase common.ActionPhase, attempts uint32, cacheStatus core.CatalogCacheStatus, endTime *time.Time) error { + ret := _m.Called(ctx, actionID, phase, attempts, cacheStatus, endTime) if len(ret) == 0 { panic("no return value specified for UpdateActionPhase") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *common.ActionIdentifier, common.ActionPhase, *time.Time) error); ok { - r0 = rf(ctx, actionID, phase, endTime) + if rf, ok := ret.Get(0).(func(context.Context, *common.ActionIdentifier, common.ActionPhase, uint32, core.CatalogCacheStatus, *time.Time) error); ok { + r0 = rf(ctx, actionID, phase, attempts, cacheStatus, endTime) } else { r0 = ret.Error(0) } @@ -802,14 +804,16 @@ type ActionRepo_UpdateActionPhase_Call struct { // - ctx context.Context // - actionID *common.ActionIdentifier // - phase common.ActionPhase +// - attempts uint32 +// - cacheStatus core.CatalogCacheStatus // - endTime *time.Time -func (_e *ActionRepo_Expecter) UpdateActionPhase(ctx interface{}, actionID interface{}, phase interface{}, endTime interface{}) *ActionRepo_UpdateActionPhase_Call { - return &ActionRepo_UpdateActionPhase_Call{Call: _e.mock.On("UpdateActionPhase", ctx, actionID, phase, endTime)} +func (_e *ActionRepo_Expecter) UpdateActionPhase(ctx interface{}, actionID interface{}, phase interface{}, attempts interface{}, cacheStatus interface{}, endTime interface{}) *ActionRepo_UpdateActionPhase_Call { + return &ActionRepo_UpdateActionPhase_Call{Call: _e.mock.On("UpdateActionPhase", ctx, actionID, phase, attempts, cacheStatus, endTime)} } -func (_c *ActionRepo_UpdateActionPhase_Call) Run(run func(ctx context.Context, actionID *common.ActionIdentifier, phase common.ActionPhase, endTime *time.Time)) *ActionRepo_UpdateActionPhase_Call { +func (_c *ActionRepo_UpdateActionPhase_Call) Run(run func(ctx context.Context, actionID *common.ActionIdentifier, phase common.ActionPhase, attempts uint32, cacheStatus core.CatalogCacheStatus, endTime *time.Time)) *ActionRepo_UpdateActionPhase_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*common.ActionIdentifier), args[2].(common.ActionPhase), args[3].(*time.Time)) + run(args[0].(context.Context), args[1].(*common.ActionIdentifier), args[2].(common.ActionPhase), args[3].(uint32), args[4].(core.CatalogCacheStatus), args[5].(*time.Time)) }) return _c } @@ -819,7 +823,7 @@ func (_c *ActionRepo_UpdateActionPhase_Call) Return(_a0 error) *ActionRepo_Updat return _c } -func (_c *ActionRepo_UpdateActionPhase_Call) RunAndReturn(run func(context.Context, *common.ActionIdentifier, common.ActionPhase, *time.Time) error) *ActionRepo_UpdateActionPhase_Call { +func (_c *ActionRepo_UpdateActionPhase_Call) RunAndReturn(run func(context.Context, *common.ActionIdentifier, common.ActionPhase, uint32, core.CatalogCacheStatus, *time.Time) error) *ActionRepo_UpdateActionPhase_Call { _c.Call.Return(run) return _c } diff --git a/runs/service/internal_run_service.go b/runs/service/internal_run_service.go index bfa0434748..a31c5aecdb 100644 --- a/runs/service/internal_run_service.go +++ b/runs/service/internal_run_service.go @@ -191,7 +191,14 @@ func (s *RunService) updateSingleActionStatus(ctx context.Context, req *workflow endTime = &t } - if err := s.repo.ActionRepo().UpdateActionPhase(ctx, req.GetActionId(), actionStatus.GetPhase(), endTime); err != nil { + if err := s.repo.ActionRepo().UpdateActionPhase( + ctx, + req.GetActionId(), + actionStatus.GetPhase(), + actionStatus.GetAttempts(), + actionStatus.GetCacheStatus(), + endTime, + ); err != nil { logger.Warnf(ctx, "UpdateActionStatus: failed to update action %s: %v", req.GetActionId().GetName(), err) return connect.NewError(connect.CodeInternal, err) }