Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions cmd/dbos/cli_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,27 @@ func testListWorkflows(t *testing.T, cliPath string, baseArgs []string, dbRole s
workflowID, exists := response["workflow_id"]
assert.True(t, exists, "Response should contain workflow_id")
assert.NotEmpty(t, workflowID, "Workflow ID should not be empty")

// Wait until the QueueWorkflow has enqueued all workflows.
// This is to avoid a race condition where the test checks for queued workflows
// before they are all enqueued.
require.Eventually(t, func() bool {
args := append([]string{"workflow", "list", "--queue", "example-queue"}, baseArgs...)
cmd := exec.Command(cliPath, args...)
cmd.Env = append(os.Environ(), "DBOS_SYSTEM_DATABASE_URL="+getDatabaseURL(dbRole))

output, err := cmd.CombinedOutput()
if err != nil {
return false
}
var workflows []dbos.WorkflowStatus
err = json.Unmarshal(output, &workflows)
if err != nil {
return false
}
return len(workflows) >= 10
}, 5*time.Second, 500*time.Millisecond, "Should find at least 10 workflows in the queue")

// Get the current time for time-based filtering
currentTime := time.Now()

Expand Down
98 changes: 26 additions & 72 deletions dbos/system_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ type systemDatabase interface {
// Child workflows
recordChildWorkflow(ctx context.Context, input recordChildWorkflowDBInput) error
checkChildWorkflow(ctx context.Context, workflowUUID string, functionID int) (*string, error)
recordChildGetResult(ctx context.Context, input recordChildGetResultDBInput) error

// Steps
recordOperationResult(ctx context.Context, input recordOperationResultDBInput) error
Expand Down Expand Up @@ -1308,51 +1307,47 @@ func (s *sysDB) awaitWorkflowResult(ctx context.Context, workflowID string, poll
}

type recordOperationResultDBInput struct {
workflowID string
stepID int
stepName string
output *string
err error
tx pgx.Tx
startedAt time.Time
completedAt time.Time
workflowID string
childWorkflowID string
stepID int
stepName string
output *string
err error
tx pgx.Tx
startedAt time.Time
completedAt time.Time
}

func (s *sysDB) recordOperationResult(ctx context.Context, input recordOperationResultDBInput) error {
startedAtMs := input.startedAt.UnixMilli()
completedAtMs := input.completedAt.UnixMilli()

query := fmt.Sprintf(`INSERT INTO %s.operation_outputs
(workflow_uuid, function_id, output, error, function_name, started_at_epoch_ms, completed_at_epoch_ms)
VALUES ($1, $2, $3, $4, $5, $6, $7)`, pgx.Identifier{s.schema}.Sanitize())

var errorString *string
if input.err != nil {
e := input.err.Error()
errorString = &e
}

columns := []string{"workflow_uuid", "function_id", "output", "error", "function_name", "started_at_epoch_ms", "completed_at_epoch_ms"}
placeholders := []string{"$1", "$2", "$3", "$4", "$5", "$6", "$7"}
args := []any{input.workflowID, input.stepID, input.output, errorString, input.stepName, startedAtMs, completedAtMs}
argCounter := 7

if input.childWorkflowID != "" {
columns = append(columns, "child_workflow_id")
argCounter++
placeholders = append(placeholders, fmt.Sprintf("$%d", argCounter))
args = append(args, input.childWorkflowID)
}

query := fmt.Sprintf(`INSERT INTO %s.operation_outputs (%s) VALUES (%s)`,
pgx.Identifier{s.schema}.Sanitize(), strings.Join(columns, ", "), strings.Join(placeholders, ", "))

var err error
if input.tx != nil {
_, err = input.tx.Exec(ctx, query,
input.workflowID,
input.stepID,
input.output,
errorString,
input.stepName,
startedAtMs,
completedAtMs,
)
_, err = input.tx.Exec(ctx, query, args...)
} else {
_, err = s.pool.Exec(ctx, query,
input.workflowID,
input.stepID,
input.output,
errorString,
input.stepName,
startedAtMs,
completedAtMs,
)
_, err = s.pool.Exec(ctx, query, args...)
}

if err != nil {
Expand Down Expand Up @@ -1435,47 +1430,6 @@ func (s *sysDB) checkChildWorkflow(ctx context.Context, workflowID string, funct
return childWorkflowID, nil
}

type recordChildGetResultDBInput struct {
parentWorkflowID string
childWorkflowID string
stepID int
output *string
err error
startedAt time.Time
completedAt time.Time
}

func (s *sysDB) recordChildGetResult(ctx context.Context, input recordChildGetResultDBInput) error {
startedAtMs := input.startedAt.UnixMilli()
completedAtMs := input.completedAt.UnixMilli()

query := fmt.Sprintf(`INSERT INTO %s.operation_outputs
(workflow_uuid, function_id, function_name, output, error, child_workflow_id, started_at_epoch_ms, completed_at_epoch_ms)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT DO NOTHING`, pgx.Identifier{s.schema}.Sanitize())

var errorString *string
if input.err != nil {
e := input.err.Error()
errorString = &e
}

_, err := s.pool.Exec(ctx, query,
input.parentWorkflowID,
input.stepID,
"DBOS.getResult",
input.output,
errorString,
input.childWorkflowID,
startedAtMs,
completedAtMs,
)
if err != nil {
return fmt.Errorf("failed to record get result: %w", err)
}
return nil
}

/*******************************/
/******* STEPS ********/
/*******************************/
Expand Down
38 changes: 20 additions & 18 deletions dbos/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,17 +245,18 @@ func (h *workflowHandle[R]) processOutcome(outcome workflowOutcome[R], startTime
if encErr != nil {
return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Errorf("serializing child workflow result: %w", encErr))
}
recordGetResultInput := recordChildGetResultDBInput{
parentWorkflowID: workflowState.workflowID,
childWorkflowID: h.workflowID,
stepID: workflowState.nextStepID(),
output: encodedOutput,
err: outcome.err,
startedAt: startTime,
completedAt: completedTime,
recordGetResultInput := recordOperationResultDBInput{
workflowID: workflowState.workflowID,
childWorkflowID: h.workflowID,
stepID: workflowState.nextStepID(),
output: encodedOutput,
err: outcome.err,
startedAt: startTime,
completedAt: completedTime,
stepName: "DBOS.getResult",
}
recordResultErr := retry(h.dbosContext, func() error {
return h.dbosContext.(*dbosContext).systemDB.recordChildGetResult(h.dbosContext, recordGetResultInput)
return h.dbosContext.(*dbosContext).systemDB.recordOperationResult(h.dbosContext, recordGetResultInput)
}, withRetrierLogger(h.dbosContext.(*dbosContext).logger))
if recordResultErr != nil {
h.dbosContext.(*dbosContext).logger.Error("failed to record get result", "error", recordResultErr)
Expand Down Expand Up @@ -309,17 +310,18 @@ func (h *workflowPollingHandle[R]) GetResult(opts ...GetResultOption) (R, error)
workflowState, ok := h.dbosContext.Value(workflowStateKey).(*workflowState)
isWithinWorkflow := ok && workflowState != nil
if isWithinWorkflow {
recordGetResultInput := recordChildGetResultDBInput{
parentWorkflowID: workflowState.workflowID,
childWorkflowID: h.workflowID,
stepID: workflowState.nextStepID(),
output: encodedStr,
err: err,
startedAt: startTime,
completedAt: completedTime,
recordGetResultInput := recordOperationResultDBInput{
workflowID: workflowState.workflowID,
childWorkflowID: h.workflowID,
stepID: workflowState.nextStepID(),
output: encodedStr,
err: err,
startedAt: startTime,
completedAt: completedTime,
stepName: "DBOS.getResult",
}
recordResultErr := retry(h.dbosContext, func() error {
return h.dbosContext.(*dbosContext).systemDB.recordChildGetResult(h.dbosContext, recordGetResultInput)
return h.dbosContext.(*dbosContext).systemDB.recordOperationResult(h.dbosContext, recordGetResultInput)
}, withRetrierLogger(h.dbosContext.(*dbosContext).logger))
if recordResultErr != nil {
h.dbosContext.(*dbosContext).logger.Error("failed to record get result", "error", recordResultErr)
Expand Down
Loading