Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions dbos/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type Client interface {
CancelWorkflow(workflowID string) error
ResumeWorkflow(workflowID string) (WorkflowHandle[any], error)
ForkWorkflow(input ForkWorkflowInput) (WorkflowHandle[any], error)
GetWorkflowSteps(workflowID string) ([]StepInfo, error)
Shutdown(timeout time.Duration) // Simply close the system DB connection pool
}

Expand Down Expand Up @@ -295,6 +296,11 @@ func (c *client) ForkWorkflow(input ForkWorkflowInput) (WorkflowHandle[any], err
return c.dbosCtx.ForkWorkflow(c.dbosCtx, input)
}

// GetWorkflowSteps retrieves the execution steps of a workflow.
func (c *client) GetWorkflowSteps(workflowID string) ([]StepInfo, error) {
return c.dbosCtx.GetWorkflowSteps(c.dbosCtx, workflowID)
}

// Shutdown gracefully shuts down the client and closes the system database connection.
func (c *client) Shutdown(timeout time.Duration) {
// Get the concrete dbosContext to access internal fields
Expand Down
66 changes: 66 additions & 0 deletions dbos/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -928,3 +928,69 @@ func TestListWorkflows(t *testing.T) {
// Verify all queue entries are cleaned up
require.True(t, queueEntriesAreCleanedUp(serverCtx), "expected queue entries to be cleaned up after list workflows tests")
}

func TestGetWorkflowSteps(t *testing.T) {
// Setup server context
serverCtx := setupDBOS(t, true, true)

// Create queue for communication
queue := NewWorkflowQueue(serverCtx, "get-workflow-steps-queue")

// Workflow with one step
stepFunction := func(ctx context.Context) (string, error) {
return "abc", nil
}

testWorkflow := func(ctx DBOSContext, input string) (string, error) {
result, err := RunAsStep(ctx, stepFunction, WithStepName("TestStep"))
if err != nil {
return "", err
}
return result, nil
}
RegisterWorkflow(serverCtx, testWorkflow, WithWorkflowName("TestWorkflow"))

// Launch server
err := Launch(serverCtx)
require.NoError(t, err)

// Setup client
databaseURL := getDatabaseURL()
config := ClientConfig{
DatabaseURL: databaseURL,
}
client, err := NewClient(context.Background(), config)
require.NoError(t, err)
t.Cleanup(func() {
if client != nil {
client.Shutdown(30 * time.Second)
}
})

// Enqueue and run the workflow
workflowID := "test-get-workflow-steps"
handle, err := Enqueue[string, string](client, queue.Name, "TestWorkflow", "test-input", WithEnqueueWorkflowID(workflowID))
require.NoError(t, err)

// Wait for workflow to complete
result, err := handle.GetResult()
require.NoError(t, err)
assert.Equal(t, "abc", result)

// Test GetWorkflowSteps with loadOutput = true
stepsWithOutput, err := client.GetWorkflowSteps(workflowID)
require.NoError(t, err)
require.Len(t, stepsWithOutput, 1, "expected exactly 1 step")

step := stepsWithOutput[0]
assert.Equal(t, 0, step.StepID, "expected step ID to be 0")
assert.Equal(t, "TestStep", step.StepName, "expected step name to be set")
assert.Nil(t, step.Error, "expected no error in step")
assert.Equal(t, "", step.ChildWorkflowID, "expected no child workflow ID")

// Verify the output wasn't loaded
require.Nil(t, step.Output, "expected output not to be loaded")

// Verify all queue entries are cleaned up
require.True(t, queueEntriesAreCleanedUp(serverCtx), "expected queue entries to be cleaned up after get workflow steps test")
}
19 changes: 15 additions & 4 deletions dbos/dbos.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,8 @@ func WithValue(ctx DBOSContext, key, val any) DBOSContext {
}
// Will do nothing if the concrete type is not dbosContext
if dbosCtx, ok := ctx.(*dbosContext); ok {
return &dbosContext{
launched := dbosCtx.launched.Load()
childCtx := &dbosContext{
ctx: context.WithValue(dbosCtx.ctx, key, val), // Spawn a new child context with the value set
logger: dbosCtx.logger,
systemDB: dbosCtx.systemDB,
Expand All @@ -205,6 +206,8 @@ func WithValue(ctx DBOSContext, key, val any) DBOSContext {
executorID: dbosCtx.executorID,
applicationID: dbosCtx.applicationID,
}
childCtx.launched.Store(launched)
return childCtx
}
return nil
}
Expand All @@ -217,7 +220,10 @@ func WithoutCancel(ctx DBOSContext) DBOSContext {
return nil
}
if dbosCtx, ok := ctx.(*dbosContext); ok {
return &dbosContext{
launched := dbosCtx.launched.Load()
// Create a new context that is not canceled when the parent is canceled
// but retains all other values
childCtx := &dbosContext{
ctx: context.WithoutCancel(dbosCtx.ctx),
logger: dbosCtx.logger,
systemDB: dbosCtx.systemDB,
Expand All @@ -228,6 +234,8 @@ func WithoutCancel(ctx DBOSContext) DBOSContext {
executorID: dbosCtx.executorID,
applicationID: dbosCtx.applicationID,
}
childCtx.launched.Store(launched)
return childCtx
}
return nil
}
Expand All @@ -240,8 +248,9 @@ func WithTimeout(ctx DBOSContext, timeout time.Duration) (DBOSContext, context.C
return nil, func() {}
}
if dbosCtx, ok := ctx.(*dbosContext); ok {
launched := dbosCtx.launched.Load()
newCtx, cancelFunc := context.WithTimeoutCause(dbosCtx.ctx, timeout, errors.New("DBOS context timeout"))
return &dbosContext{
childCtx := &dbosContext{
ctx: newCtx,
logger: dbosCtx.logger,
systemDB: dbosCtx.systemDB,
Expand All @@ -251,7 +260,9 @@ func WithTimeout(ctx DBOSContext, timeout time.Duration) (DBOSContext, context.C
applicationVersion: dbosCtx.applicationVersion,
executorID: dbosCtx.executorID,
applicationID: dbosCtx.applicationID,
}, cancelFunc
}
childCtx.launched.Store(launched)
return childCtx, cancelFunc
}
return nil, func() {}
}
Expand Down
12 changes: 3 additions & 9 deletions dbos/serialization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,6 @@ import (
"github.com/stretchr/testify/require"
)

/** Test serialization and deserialization
[x] Built in types
[x] User defined types (structs)
[x] Workflow inputs/outputs
[x] Step inputs/outputs
[x] Direct handlers, polling handler, list workflows results, get step infos
[x] Set/get event with user defined types
*/

// Builtin types
func encodingStepBuiltinTypes(_ context.Context, input int) (int, error) {
return input, errors.New("step error")
Expand Down Expand Up @@ -76,6 +67,9 @@ func TestWorkflowEncoding(t *testing.T) {
RegisterWorkflow(executor, encodingWorkflowBuiltinTypes)
RegisterWorkflow(executor, encodingWorkflowStruct)

err := Launch(executor)
require.NoError(t, err)

t.Run("BuiltinTypes", func(t *testing.T) {
// Test a workflow that uses a built-in type (string)
directHandle, err := RunWorkflow(executor, encodingWorkflowBuiltinTypes, "test")
Expand Down
15 changes: 10 additions & 5 deletions dbos/system_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type systemDatabase interface {
// Steps
recordOperationResult(ctx context.Context, input recordOperationResultDBInput) error
checkOperationExecution(ctx context.Context, input checkOperationExecutionDBInput) (*recordedResult, error)
getWorkflowSteps(ctx context.Context, workflowID string) ([]StepInfo, error)
getWorkflowSteps(ctx context.Context, input getWorkflowStepsInput) ([]StepInfo, error)

// Communication (special steps)
send(ctx context.Context, input WorkflowSendInput) error
Expand Down Expand Up @@ -1457,13 +1457,18 @@ type StepInfo struct {
ChildWorkflowID string // The ID of a child workflow spawned by this step (if applicable)
}

func (s *sysDB) getWorkflowSteps(ctx context.Context, workflowID string) ([]StepInfo, error) {
type getWorkflowStepsInput struct {
workflowID string
loadOutput bool
}

func (s *sysDB) getWorkflowSteps(ctx context.Context, input getWorkflowStepsInput) ([]StepInfo, error) {
query := fmt.Sprintf(`SELECT function_id, function_name, output, error, child_workflow_id
FROM %s.operation_outputs
WHERE workflow_uuid = $1
ORDER BY function_id ASC`, pgx.Identifier{s.schema}.Sanitize())

rows, err := s.pool.Query(ctx, query, workflowID)
rows, err := s.pool.Query(ctx, query, input.workflowID)
if err != nil {
return nil, fmt.Errorf("failed to query workflow steps: %w", err)
}
Expand All @@ -1481,8 +1486,8 @@ func (s *sysDB) getWorkflowSteps(ctx context.Context, workflowID string) ([]Step
return nil, fmt.Errorf("failed to scan step row: %w", err)
}

// Deserialize output if present
if outputString != nil {
// Deserialize output if present and loadOutput is true
if input.loadOutput && outputString != nil {
output, err := deserialize(outputString)
if err != nil {
return nil, fmt.Errorf("failed to deserialize output: %w", err)
Expand Down
15 changes: 13 additions & 2 deletions dbos/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -1956,14 +1956,25 @@ func ListWorkflows(ctx DBOSContext, opts ...ListWorkflowsOption) ([]WorkflowStat
}

func (c *dbosContext) GetWorkflowSteps(_ DBOSContext, workflowID string) ([]StepInfo, error) {
var loadOutput bool
if c.launched.Load() {
loadOutput = true
} else {
loadOutput = false
}
getWorkflowStepsInput := getWorkflowStepsInput{
workflowID: workflowID,
loadOutput: loadOutput,
}

workflowState, ok := c.Value(workflowStateKey).(*workflowState)
isWithinWorkflow := ok && workflowState != nil
if isWithinWorkflow {
return RunAsStep(c, func(ctx context.Context) ([]StepInfo, error) {
return c.systemDB.getWorkflowSteps(ctx, workflowID)
return c.systemDB.getWorkflowSteps(ctx, getWorkflowStepsInput)
}, WithStepName("DBOS.getWorkflowSteps"))
} else {
return c.systemDB.getWorkflowSteps(c, workflowID)
return c.systemDB.getWorkflowSteps(c, getWorkflowStepsInput)
}
}

Expand Down
Loading
Loading