Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 74 additions & 10 deletions runs/repository/impl/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,13 @@ func (r *actionRepo) ListRuns(ctx context.Context, req *workflow.ListRunsRequest

// AbortRun aborts a run and all its actions
func (r *actionRepo) AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) error {
// Update the run action to aborted
now := time.Now()
updates := map[string]interface{}{
"phase": int32(common.ActionPhase_ACTION_PHASE_ABORTED),
"updated_at": time.Now(),
"phase": int32(common.ActionPhase_ACTION_PHASE_ABORTED),
"updated_at": now,
"abort_requested_at": now,
"abort_attempt_count": 0,
"abort_reason": reason,
}

result := r.db.WithContext(ctx).
Expand All @@ -270,7 +273,7 @@ func (r *actionRepo) AbortRun(ctx context.Context, runID *common.RunIdentifier,
return fmt.Errorf("failed to abort run: %w", result.Error)
}

// Notify subscribers
// Notify run subscribers.
r.notifyRunUpdate(ctx, runID)

logger.Infof(ctx, "Aborted run: %s/%s/%s/%s", runID.Org, runID.Project, runID.Domain, runID.Name)
Expand Down Expand Up @@ -509,28 +512,89 @@ func (r *actionRepo) UpdateActionPhase(

// AbortAction marks a single action as aborted and records the abort request.
//
// It writes phase (ACTION_PHASE_ABORTED), updated_at, abort_requested_at,
// abort_attempt_count (reset to 0) and abort_reason onto the matching
// models.Action row, then notifies action subscribers and logs the abort.
// It returns a wrapped error if the database update fails.
//
// NOTE(review): abortedBy is accepted but not referenced in the visible body.
func (r *actionRepo) AbortAction(ctx context.Context, actionID *common.ActionIdentifier, reason string, abortedBy *common.EnrichedIdentity) error {
now := time.Now()
updates := map[string]interface{}{
"phase": int32(common.ActionPhase_ACTION_PHASE_ABORTED),
"updated_at": time.Now(),
"phase": int32(common.ActionPhase_ACTION_PHASE_ABORTED),
"updated_at": now,
"abort_requested_at": now,
"abort_attempt_count": 0,
"abort_reason": reason,
}

result := r.db.WithContext(ctx).
Model(&models.Action{}).
Where("org = ? AND project = ? AND domain = ? AND name = ?",
actionID.Run.Org, actionID.Run.Project, actionID.Run.Domain, actionID.Name).
Where("org = ? AND project = ? AND domain = ? AND run_name = ? AND name = ?",
actionID.Run.Org, actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name).
Updates(updates)

if result.Error != nil {
return fmt.Errorf("failed to abort action: %w", result.Error)
}

// Notify subscribers
// Notify action subscribers.
r.notifyActionUpdate(ctx, actionID)

logger.Infof(ctx, "Aborted action: %s", actionID.Name)
return nil
}

// ListPendingAborts loads every action row whose abort_requested_at column is
// non-NULL, i.e. actions with an abort request that has not been cleared yet
// (awaiting pod termination).
//
// It returns the matching rows, or a wrapped error if the query fails.
func (r *actionRepo) ListPendingAborts(ctx context.Context) ([]*models.Action, error) {
	var pending []*models.Action

	query := r.db.WithContext(ctx).Where("abort_requested_at IS NOT NULL")
	if err := query.Find(&pending).Error; err != nil {
		return nil, fmt.Errorf("failed to list pending aborts: %w", err)
	}

	return pending, nil
}

// MarkAbortAttempt increments abort_attempt_count for the given action and
// returns the new value. Called by the reconciler before each
// actionsClient.Abort call.
//
// The increment and the read-back run inside a single transaction so that the
// returned count is the value written by this call. Issued as two independent
// statements, a concurrent increment could land between them and inflate the
// value returned here.
// NOTE(review): this relies on the UPDATE's row lock being held until commit
// (PostgreSQL behavior); confirm for any other backend in use.
//
// Errors from the update or the read are wrapped with context. If no row
// matches actionID, the read fails and its wrapped error is returned.
// A failure to begin or commit the transaction is returned as-is.
func (r *actionRepo) MarkAbortAttempt(ctx context.Context, actionID *common.ActionIdentifier) (int, error) {
	// Same predicate and arguments identify the row for both statements.
	const byActionID = "org = ? AND project = ? AND domain = ? AND run_name = ? AND name = ?"
	key := []interface{}{
		actionID.Run.Org, actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name,
	}

	var count int
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		bump := tx.Model(&models.Action{}).
			Where(byActionID, key...).
			Updates(map[string]interface{}{
				// Increment in SQL rather than read-modify-write in Go.
				"abort_attempt_count": gorm.Expr("abort_attempt_count + 1"),
				"updated_at":          time.Now(),
			})
		if bump.Error != nil {
			return fmt.Errorf("failed to mark abort attempt: %w", bump.Error)
		}

		// Re-fetch the updated count within the same transaction.
		var action models.Action
		if err := tx.Select("abort_attempt_count").
			Where(byActionID, key...).
			First(&action).Error; err != nil {
			return fmt.Errorf("failed to read abort attempt count: %w", err)
		}

		count = action.AbortAttemptCount
		return nil
	})
	if err != nil {
		return 0, err
	}

	return count, nil
}

// ClearAbortRequest removes the pending-abort marker from an action once the
// pod is confirmed terminated: abort_requested_at and abort_reason are set to
// NULL, abort_attempt_count is reset to 0, and updated_at is refreshed.
//
// It returns a wrapped error if the database update fails.
func (r *actionRepo) ClearAbortRequest(ctx context.Context, actionID *common.ActionIdentifier) error {
	reset := map[string]interface{}{
		"abort_requested_at":  nil,
		"abort_attempt_count": 0,
		"abort_reason":        nil,
		"updated_at":          time.Now(),
	}

	err := r.db.WithContext(ctx).
		Model(&models.Action{}).
		Where("org = ? AND project = ? AND domain = ? AND run_name = ? AND name = ?",
			actionID.Run.Org, actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name).
		Updates(reset).
		Error
	if err != nil {
		return fmt.Errorf("failed to clear abort request: %w", err)
	}

	return nil
}

// UpdateActionState updates the state of an action
func (r *actionRepo) UpdateActionState(ctx context.Context, actionID *common.ActionIdentifier, state string) error {
// Parse the state JSON to extract the phase
Expand Down Expand Up @@ -936,11 +1000,11 @@ func (r *actionRepo) startPostgresListener() {
select {
case ch <- notif.Extra:
default:
// Channel full, skip this subscriber
logger.Warnf(context.Background(), "Action subscriber channel full, dropping notification")
}
}
r.mu.RUnlock()

}

case <-time.After(90 * time.Second):
Expand Down
3 changes: 3 additions & 0 deletions runs/repository/impl/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ func setupActionDB(t *testing.T) *gorm.DB {
action_details BLOB,
detailed_info BLOB,
run_spec BLOB,
abort_requested_at DATETIME,
abort_attempt_count INTEGER NOT NULL DEFAULT 0,
abort_reason TEXT,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
ended_at DATETIME,
Expand Down
5 changes: 5 additions & 0 deletions runs/repository/interfaces/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ type ActionRepo interface {
UpdateActionPhase(ctx context.Context, actionID *common.ActionIdentifier, phase common.ActionPhase, attempts uint32, cacheStatus core.CatalogCacheStatus, endTime *time.Time) error
AbortAction(ctx context.Context, actionID *common.ActionIdentifier, reason string, abortedBy *common.EnrichedIdentity) error

// Abort reconciliation — used by the background AbortReconciler.
ListPendingAborts(ctx context.Context) ([]*models.Action, error)
MarkAbortAttempt(ctx context.Context, actionID *common.ActionIdentifier) (attemptCount int, err error)
ClearAbortRequest(ctx context.Context, actionID *common.ActionIdentifier) error

// Watch operations (for streaming)
WatchRunUpdates(ctx context.Context, runID *common.RunIdentifier, updates chan<- *models.Run, errs chan<- error)
WatchAllRunUpdates(ctx context.Context, updates chan<- *models.Run, errs chan<- error)
Expand Down
162 changes: 162 additions & 0 deletions runs/repository/mocks/action_repo.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions runs/repository/models/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ type Action struct {
// interruptible, cluster, etc.) for this action's run.
RunSpec []byte `gorm:"type:bytea" db:"run_spec"`

// Abort tracking — set when a user requests abort; cleared once the pod is confirmed terminated.
AbortRequestedAt *time.Time `gorm:"index:idx_actions_abort_pending" db:"abort_requested_at"`
AbortAttemptCount int `gorm:"not null;default:0" db:"abort_attempt_count"`
AbortReason *string `db:"abort_reason"`

// Timestamps
// CreatedAt is set by the DB (NOW()) on insert — represents action start time.
CreatedAt time.Time `gorm:"not null;default:CURRENT_TIMESTAMP;index:idx_actions_created" db:"created_at"`
Expand Down
Loading
Loading