diff --git a/dbos/dbos.go b/dbos/dbos.go index ef3d0be7..06daa504 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -71,11 +71,11 @@ type Config struct { AdminServer bool } -// ProcessConfig merges configuration from two sources in order of precedence: +// processConfig merges configuration from two sources in order of precedence: // 1. programmatic configuration // 2. environment variables // Finally, it applies default values if needed. -func ProcessConfig(inputConfig *Config) (*Config, error) { +func processConfig(inputConfig *Config) (*Config, error) { // First check required fields if len(inputConfig.DatabaseURL) == 0 { return nil, fmt.Errorf("missing required config field: databaseURL") @@ -133,7 +133,7 @@ func Initialize(inputConfig Config) error { } // Load & process the configuration - config, err := ProcessConfig(&inputConfig) + config, err := processConfig(&inputConfig) if err != nil { return newInitializationError(err.Error()) } diff --git a/dbos/dbos_test.go b/dbos/dbos_test.go new file mode 100644 index 00000000..9a639415 --- /dev/null +++ b/dbos/dbos_test.go @@ -0,0 +1,87 @@ +package dbos + +import ( + "context" + "encoding/hex" + "maps" + "testing" +) + +func TestConfigValidationErrorTypes(t *testing.T) { + databaseURL := getDatabaseURL(t) + + t.Run("FailsWithoutAppName", func(t *testing.T) { + config := Config{ + DatabaseURL: databaseURL, + } + + err := Initialize(config) + if err == nil { + t.Fatal("expected error when app name is missing, but got none") + } + + dbosErr, ok := err.(*DBOSError) + if !ok { + t.Fatalf("expected DBOSError, got %T", err) + } + + if dbosErr.Code != InitializationError { + t.Fatalf("expected InitializationError code, got %v", dbosErr.Code) + } + + expectedMsg := "Error initializing DBOS Transact: missing required config field: appName" + if dbosErr.Message != expectedMsg { + t.Fatalf("expected error message '%s', got '%s'", expectedMsg, dbosErr.Message) + } + }) + + t.Run("FailsWithoutDatabaseURL", func(t *testing.T) { + config := Config{ + AppName: "test-app", + } + + err := Initialize(config) + if err == nil { + t.Fatal("expected error when database URL is missing, but got none") + } + + dbosErr, ok := err.(*DBOSError) + if !ok { + t.Fatalf("expected DBOSError, got %T", err) + } + + if dbosErr.Code != InitializationError { + t.Fatalf("expected InitializationError code, got %v", dbosErr.Code) + } + + expectedMsg := "Error initializing DBOS Transact: missing required config field: databaseURL" + if dbosErr.Message != expectedMsg { + t.Fatalf("expected error message '%s', got '%s'", expectedMsg, dbosErr.Message) + } + }) +} +func TestAppVersion(t *testing.T) { + if _, err := hex.DecodeString(_APP_VERSION); err != nil { + t.Fatalf("APP_VERSION is not a valid hex string: %v", err) + } + + // Save the original registry content + originalRegistry := make(map[string]workflowRegistryEntry) + maps.Copy(originalRegistry, registry) + + // Restore the registry after the test + defer func() { + registry = originalRegistry + }() + + // Replace the registry and verify the hash is different + registry = make(map[string]workflowRegistryEntry) + + WithWorkflow(func(ctx context.Context, input string) (string, error) { + return "new-registry-workflow-" + input, nil + }) + hash2 := computeApplicationVersion() + if _APP_VERSION == hash2 { + t.Fatalf("APP_VERSION hash did not change after replacing registry") + } +} diff --git a/dbos/system_database.go b/dbos/system_database.go index 43a4f9fd..b2c652f4 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -5,6 +5,7 @@ import ( "embed" "errors" "fmt" + "net/url" "strings" "sync" "time" @@ -98,7 +99,8 @@ func createDatabaseIfNotExists(databaseURL string) error { //go:embed migrations/*.sql var migrationFiles embed.FS -// TODO: must use the systemdb name +const _DBOS_MIGRATION_TABLE = "dbos_schema_migrations" + func runMigrations(databaseURL string) error { // Change the driver to pgx5 databaseURL = "pgx5://" + strings.TrimPrefix(databaseURL, "postgres://") @@ -109,6 +111,20 @@ func runMigrations(databaseURL string) error { return newInitializationError(fmt.Sprintf("failed to create migration source: %v", err)) } + // Add custom migration table name to avoid conflicts with user migrations + // Parse the URL to properly determine where to add the query parameter + parsedURL, err := url.Parse(databaseURL) + if err != nil { + return newInitializationError(fmt.Sprintf("failed to parse database URL: %v", err)) + } + + // Check if query parameters already exist + separator := "?" + if parsedURL.RawQuery != "" { + separator = "&" + } + databaseURL += separator + "x-migrations-table=" + _DBOS_MIGRATION_TABLE + // Create migrator m, err := migrate.NewWithSourceInstance("iofs", d, databaseURL) if err != nil { @@ -646,12 +662,12 @@ func (s *systemDatabase) AwaitWorkflowResult(ctx context.Context, workflowID str } type recordOperationResultDBInput struct { - workflowID string - operationID int - operationName string - output any - err error - tx pgx.Tx + workflowID string + stepID int + stepName string + output any + err error + tx pgx.Tx } func (s *systemDatabase) RecordOperationResult(ctx context.Context, input recordOperationResultDBInput) error { @@ -675,18 +691,18 @@ func (s *systemDatabase) RecordOperationResult(ctx context.Context, input record if input.tx != nil { commandTag, err = input.tx.Exec(ctx, query, input.workflowID, - input.operationID, + input.stepID, outputString, errorString, - input.operationName, + input.stepName, ) } else { commandTag, err = s.pool.Exec(ctx, query, input.workflowID, - input.operationID, + input.stepID, outputString, errorString, - input.operationName, + input.stepName, ) } @@ -716,8 +732,8 @@ func (s *systemDatabase) RecordOperationResult(ctx context.Context, input record type recordChildWorkflowDBInput struct { parentWorkflowID string childWorkflowID string - functionID int - functionName string + stepID int + stepName string tx pgx.Tx } @@ -732,15 +748,15 @@ func (s *systemDatabase) RecordChildWorkflow(ctx context.Context, input recordCh if input.tx != nil { commandTag, err = input.tx.Exec(ctx, query, input.parentWorkflowID, - input.functionID, - input.functionName, + input.stepID, + input.stepName, input.childWorkflowID, ) } else { commandTag, err = s.pool.Exec(ctx, query, input.parentWorkflowID, - input.functionID, - input.functionName, + input.stepID, + input.stepName, input.childWorkflowID, ) } @@ -750,7 +766,7 @@ func (s *systemDatabase) RecordChildWorkflow(ctx context.Context, input recordCh if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" { return fmt.Errorf( "child workflow %s already registered for parent workflow %s (operation ID: %d)", - input.childWorkflowID, input.parentWorkflowID, input.functionID) + input.childWorkflowID, input.parentWorkflowID, input.stepID) } return fmt.Errorf("failed to record child workflow: %w", err) } @@ -782,7 +798,7 @@ func (s *systemDatabase) CheckChildWorkflow(ctx context.Context, workflowID stri type recordChildGetResultDBInput struct { parentWorkflowID string childWorkflowID string - operationID int + stepID int output string err error } @@ -801,7 +817,7 @@ func (s *systemDatabase) RecordChildGetResult(ctx context.Context, input recordC _, err := s.pool.Exec(ctx, query, input.parentWorkflowID, - input.operationID, + input.stepID, "DBOS.getResult", input.output, errorString, @@ -823,10 +839,10 @@ type recordedResult struct { } type checkOperationExecutionDBInput struct { - workflowID string - operationID int - functionName string - tx pgx.Tx + workflowID string + stepID int + stepName string + tx pgx.Tx } func (s *systemDatabase) CheckOperationExecution(ctx context.Context, input checkOperationExecutionDBInput) (*recordedResult, error) { @@ -848,7 +864,7 @@ func (s *systemDatabase) CheckOperationExecution(ctx context.Context, input chec workflowStatusQuery := `SELECT status FROM dbos.workflow_status WHERE workflow_uuid = $1` // Second query: Retrieve operation outputs if they exist - operationOutputQuery := `SELECT output, error, function_name + stepOutputQuery := `SELECT output, error, function_name FROM dbos.operation_outputs WHERE workflow_uuid = $1 AND function_id = $2` @@ -873,7 +889,7 @@ func (s *systemDatabase) CheckOperationExecution(ctx context.Context, input chec var errorStr *string var recordedFunctionName string - err = tx.QueryRow(ctx, operationOutputQuery, input.workflowID, input.operationID).Scan(&outputString, &errorStr, &recordedFunctionName) + err = tx.QueryRow(ctx, stepOutputQuery, input.workflowID, input.stepID).Scan(&outputString, &errorStr, &recordedFunctionName) // If there are no operation outputs, return nil if err != nil { @@ -884,8 +900,8 @@ func (s *systemDatabase) CheckOperationExecution(ctx context.Context, input chec } // If the provided and recorded function name are different, throw an exception - if input.functionName != recordedFunctionName { - return nil, newUnexpectedStepError(input.workflowID, input.operationID, input.functionName, recordedFunctionName) + if input.stepName != recordedFunctionName { + return nil, newUnexpectedStepError(input.workflowID, input.stepID, input.stepName, recordedFunctionName) } output, err := deserialize(outputString) @@ -1009,16 +1025,16 @@ func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) erro functionName := "DBOS.send" // Get workflow state from context - workflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState) - if !ok || workflowState == nil { + wfState, ok := ctx.Value(workflowStateKey).(*workflowState) + if !ok || wfState == nil { return newStepExecutionError("", functionName, "workflow state not found in context: are you running this step within a workflow?") } - if workflowState.isWithinStep { - return newStepExecutionError(workflowState.WorkflowID, functionName, "cannot call Send within a step") + if wfState.isWithinStep { + return newStepExecutionError(wfState.workflowID, functionName, "cannot call Send within a step") } - stepID := workflowState.NextStepID() + stepID := wfState.NextStepID() tx, err := s.pool.Begin(ctx) if err != nil { @@ -1028,10 +1044,10 @@ func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) erro // Check if operation was already executed and do nothing if so checkInput := checkOperationExecutionDBInput{ - workflowID: workflowState.WorkflowID, - operationID: stepID, - functionName: functionName, - tx: tx, + workflowID: wfState.workflowID, + stepID: stepID, + stepName: functionName, + tx: tx, } recordedResult, err := s.CheckOperationExecution(ctx, checkInput) if err != nil { @@ -1068,12 +1084,12 @@ func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) erro // Record the operation result recordInput := recordOperationResultDBInput{ - workflowID: workflowState.WorkflowID, - operationID: stepID, - operationName: functionName, - output: nil, - err: nil, - tx: tx, + workflowID: wfState.workflowID, + stepID: stepID, + stepName: functionName, + output: nil, + err: nil, + tx: tx, } err = s.RecordOperationResult(ctx, recordInput) @@ -1095,17 +1111,17 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any // Get workflow state from context // XXX these checks might be better suited for outside of the system db code. We'll see when we implement the client. - workflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState) - if !ok || workflowState == nil { + wfState, ok := ctx.Value(workflowStateKey).(*workflowState) + if !ok || wfState == nil { return nil, newStepExecutionError("", functionName, "workflow state not found in context: are you running this step within a workflow?") } - if workflowState.isWithinStep { - return nil, newStepExecutionError(workflowState.WorkflowID, functionName, "cannot call Recv within a step") + if wfState.isWithinStep { + return nil, newStepExecutionError(wfState.workflowID, functionName, "cannot call Recv within a step") } - stepID := workflowState.NextStepID() - destinationID := workflowState.WorkflowID + stepID := wfState.NextStepID() + destinationID := wfState.workflowID // Set default topic if not provided topic := _DBOS_NULL_TOPIC @@ -1116,9 +1132,9 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any // Check if operation was already executed // XXX this might not need to be in the transaction checkInput := checkOperationExecutionDBInput{ - workflowID: destinationID, - operationID: stepID, - functionName: functionName, + workflowID: destinationID, + stepID: stepID, + stepName: functionName, } recordedResult, err := s.CheckOperationExecution(ctx, checkInput) if err != nil { @@ -1217,11 +1233,11 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any // Record the operation result recordInput := recordOperationResultDBInput{ - workflowID: destinationID, - operationID: stepID, - operationName: functionName, - output: message, - tx: tx, + workflowID: destinationID, + stepID: stepID, + stepName: functionName, + output: message, + tx: tx, } err = s.RecordOperationResult(ctx, recordInput) if err != nil { @@ -1239,16 +1255,16 @@ func (s *systemDatabase) SetEvent(ctx context.Context, input WorkflowSetEventInp functionName := "DBOS.setEvent" // Get workflow state from context - workflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState) - if !ok || workflowState == nil { + wfState, ok := ctx.Value(workflowStateKey).(*workflowState) + if !ok || wfState == nil { return newStepExecutionError("", functionName, "workflow state not found in context: are you running this step within a workflow?") } - if workflowState.isWithinStep { - return newStepExecutionError(workflowState.WorkflowID, functionName, "cannot call SetEvent within a step") + if wfState.isWithinStep { + return newStepExecutionError(wfState.workflowID, functionName, "cannot call SetEvent within a step") } - stepID := workflowState.NextStepID() + stepID := wfState.NextStepID() tx, err := s.pool.Begin(ctx) if err != nil { @@ -1258,10 +1274,10 @@ func (s *systemDatabase) SetEvent(ctx context.Context, input WorkflowSetEventInp // Check if operation was already executed and do nothing if so checkInput := checkOperationExecutionDBInput{ - workflowID: workflowState.WorkflowID, - operationID: stepID, - functionName: functionName, - tx: tx, + workflowID: wfState.workflowID, + stepID: stepID, + stepName: functionName, + tx: tx, } recordedResult, err := s.CheckOperationExecution(ctx, checkInput) if err != nil { @@ -1284,19 +1300,19 @@ func (s *systemDatabase) SetEvent(ctx context.Context, input WorkflowSetEventInp ON CONFLICT (workflow_uuid, key) DO UPDATE SET value = EXCLUDED.value` - _, err = tx.Exec(ctx, insertQuery, workflowState.WorkflowID, input.Key, messageString) + _, err = tx.Exec(ctx, insertQuery, wfState.workflowID, input.Key, messageString) if err != nil { return fmt.Errorf("failed to insert/update workflow event: %w", err) } // Record the operation result recordInput := recordOperationResultDBInput{ - workflowID: workflowState.WorkflowID, - operationID: stepID, - operationName: functionName, - output: nil, - err: nil, - tx: tx, + workflowID: wfState.workflowID, + stepID: stepID, + stepName: functionName, + output: nil, + err: nil, + tx: tx, } err = s.RecordOperationResult(ctx, recordInput) @@ -1316,22 +1332,22 @@ func (s *systemDatabase) GetEvent(ctx context.Context, input WorkflowGetEventInp functionName := "DBOS.getEvent" // Get workflow state from context (optional for GetEvent as we can get an event from outside a workflow) - workflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState) + wfState, ok := ctx.Value(workflowStateKey).(*workflowState) var stepID int var isInWorkflow bool - if ok && workflowState != nil { + if ok && wfState != nil { isInWorkflow = true - if workflowState.isWithinStep { - return nil, newStepExecutionError(workflowState.WorkflowID, functionName, "cannot call GetEvent within a step") + if wfState.isWithinStep { + return nil, newStepExecutionError(wfState.workflowID, functionName, "cannot call GetEvent within a step") } - stepID = workflowState.NextStepID() + stepID = wfState.NextStepID() // Check if operation was already executed (only if in workflow) checkInput := checkOperationExecutionDBInput{ - workflowID: workflowState.WorkflowID, - operationID: stepID, - functionName: functionName, + workflowID: wfState.workflowID, + stepID: stepID, + stepName: functionName, } recordedResult, err := s.CheckOperationExecution(ctx, checkInput) if err != nil { @@ -1414,11 +1430,11 @@ func (s *systemDatabase) GetEvent(ctx context.Context, input WorkflowGetEventInp // Record the operation result if this is called within a workflow if isInWorkflow { recordInput := recordOperationResultDBInput{ - workflowID: workflowState.WorkflowID, - operationID: stepID, - operationName: functionName, - output: value, - err: nil, + workflowID: wfState.workflowID, + stepID: stepID, + stepName: functionName, + output: value, + err: nil, } err = s.RecordOperationResult(ctx, recordInput) diff --git a/dbos/workflow.go b/dbos/workflow.go index fc696555..51f6b1ba 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -54,16 +54,16 @@ type WorkflowStatus struct { Priority int `json:"priority"` } -// WorkflowState holds the runtime state for a workflow execution +// workflowState holds the runtime state for a workflow execution // TODO: this should be an internal type. Workflows should have aptly named getters to access the state -type WorkflowState struct { - WorkflowID string +type workflowState struct { + workflowID string stepCounter int isWithinStep bool } // NextStepID returns the next step ID and increments the counter -func (ws *WorkflowState) NextStepID() int { +func (ws *workflowState) NextStepID() int { ws.stepCounter++ return ws.stepCounter } @@ -98,17 +98,17 @@ func (h *workflowHandle[R]) GetResult(ctx context.Context) (R, error) { return *new(R), errors.New("workflow result channel is already closed. Did you call GetResult() twice on the same workflow handle?") } // If we are calling GetResult inside a workflow, record the result as a step result - parentWorkflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState) + parentWorkflowState, ok := ctx.Value(workflowStateKey).(*workflowState) isChildWorkflow := ok && parentWorkflowState != nil if isChildWorkflow { encodedOutput, encErr := serialize(outcome.result) if encErr != nil { - return *new(R), newWorkflowExecutionError(parentWorkflowState.WorkflowID, fmt.Sprintf("serializing child workflow result: %v", encErr)) + return *new(R), newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("serializing child workflow result: %v", encErr)) } recordGetResultInput := recordChildGetResultDBInput{ - parentWorkflowID: parentWorkflowState.WorkflowID, + parentWorkflowID: parentWorkflowState.workflowID, childWorkflowID: h.workflowID, - operationID: parentWorkflowState.NextStepID(), + stepID: parentWorkflowState.NextStepID(), output: encodedOutput, err: outcome.err, } @@ -153,17 +153,17 @@ func (h *workflowPollingHandle[R]) GetResult(ctx context.Context) (R, error) { return *new(R), newWorkflowUnexpectedResultType(h.workflowID, fmt.Sprintf("%T", new(R)), fmt.Sprintf("%T", result)) } // If we are calling GetResult inside a workflow, record the result as a step result - parentWorkflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState) + parentWorkflowState, ok := ctx.Value(workflowStateKey).(*workflowState) isChildWorkflow := ok && parentWorkflowState != nil if isChildWorkflow { encodedOutput, encErr := serialize(typedResult) if encErr != nil { - return *new(R), newWorkflowExecutionError(parentWorkflowState.WorkflowID, fmt.Sprintf("serializing child workflow result: %v", encErr)) + return *new(R), newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("serializing child workflow result: %v", encErr)) } recordGetResultInput := recordChildGetResultDBInput{ - parentWorkflowID: parentWorkflowState.WorkflowID, + parentWorkflowID: parentWorkflowState.workflowID, childWorkflowID: h.workflowID, - operationID: parentWorkflowState.NextStepID(), + stepID: parentWorkflowState.NextStepID(), output: encodedOutput, err: err, } @@ -200,10 +200,10 @@ func (h *workflowPollingHandle[R]) GetWorkflowID() string { /**********************************/ /******* WORKFLOW REGISTRY *******/ /**********************************/ -type TypedErasedWorkflowWrapperFunc func(ctx context.Context, input any, opts ...workflowOption) (WorkflowHandle[any], error) +type typedErasedWorkflowWrapperFunc func(ctx context.Context, input any, opts ...workflowOption) (WorkflowHandle[any], error) type workflowRegistryEntry struct { - wrappedFunction TypedErasedWorkflowWrapperFunc + wrappedFunction typedErasedWorkflowWrapperFunc maxRetries int } @@ -211,7 +211,7 @@ var registry = make(map[string]workflowRegistryEntry) var regMutex sync.RWMutex // Register adds a workflow function to the registry (thread-safe, only once per name) -func registerWorkflow(fqn string, fn TypedErasedWorkflowWrapperFunc, maxRetries int) { +func registerWorkflow(fqn string, fn typedErasedWorkflowWrapperFunc, maxRetries int) { regMutex.Lock() defer regMutex.Unlock() @@ -334,7 +334,7 @@ func WithWorkflow[P any, R any](fn WorkflowFunc[P, R], opts ...workflowRegistrat type contextKey string // TODO this should be a private type, once we have proper getter for a workflow state -const WorkflowStateKey contextKey = "workflowState" +const workflowStateKey contextKey = "workflowState" type WorkflowFunc[P any, R any] func(ctx context.Context, input P) (R, error) type WorkflowWrapperFunc[P any, R any] func(ctx context.Context, input P, opts ...workflowOption) (WorkflowHandle[R], error) @@ -399,7 +399,7 @@ func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], inp dbosWorkflowContext := context.Background() // Check if we are within a workflow (and thus a child workflow) - parentWorkflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState) + parentWorkflowState, ok := ctx.Value(workflowStateKey).(*workflowState) isChildWorkflow := ok && parentWorkflowState != nil // TODO Check if cancelled @@ -409,7 +409,7 @@ func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], inp if params.workflowID == "" { if isChildWorkflow { stepID := parentWorkflowState.NextStepID() - workflowID = fmt.Sprintf("%s-%d", parentWorkflowState.WorkflowID, stepID) + workflowID = fmt.Sprintf("%s-%d", parentWorkflowState.workflowID, stepID) } else { workflowID = uuid.New().String() } @@ -419,9 +419,9 @@ func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], inp // If this is a child workflow that has already been recorded in operations_output, return directly a polling handle if isChildWorkflow { - childWorkflowID, err := dbos.systemDB.CheckChildWorkflow(dbosWorkflowContext, parentWorkflowState.WorkflowID, parentWorkflowState.stepCounter) + childWorkflowID, err := dbos.systemDB.CheckChildWorkflow(dbosWorkflowContext, parentWorkflowState.workflowID, parentWorkflowState.stepCounter) if err != nil { - return nil, newWorkflowExecutionError(parentWorkflowState.WorkflowID, fmt.Sprintf("checking child workflow: %v", err)) + return nil, newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("checking child workflow: %v", err)) } if childWorkflowID != nil { return &workflowPollingHandle[R]{workflowID: *childWorkflowID}, nil @@ -481,15 +481,15 @@ func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], inp // Get the step ID that was used for generating the child workflow ID stepID := parentWorkflowState.stepCounter childInput := recordChildWorkflowDBInput{ - parentWorkflowID: parentWorkflowState.WorkflowID, + parentWorkflowID: parentWorkflowState.workflowID, childWorkflowID: workflowStatus.ID, - functionName: runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), // Will need to test this - functionID: stepID, + stepName: runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), // Will need to test this + stepID: stepID, tx: tx, } err = dbos.systemDB.RecordChildWorkflow(dbosWorkflowContext, childInput) if err != nil { - return nil, newWorkflowExecutionError(parentWorkflowState.WorkflowID, fmt.Sprintf("recording child workflow: %v", err)) + return nil, newWorkflowExecutionError(parentWorkflowState.workflowID, fmt.Sprintf("recording child workflow: %v", err)) } } @@ -510,13 +510,13 @@ func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], inp } // Create workflow state to track step execution - workflowState := &WorkflowState{ - WorkflowID: workflowStatus.ID, + wfState := &workflowState{ + workflowID: workflowStatus.ID, stepCounter: -1, } // Run the function in a goroutine - augmentUserContext := context.WithValue(ctx, WorkflowStateKey, workflowState) + augmentUserContext := context.WithValue(ctx, workflowStateKey, wfState) go func() { result, err := fn(augmentUserContext, input) status := WorkflowStatusSuccess @@ -600,7 +600,7 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op return *new(R), newStepExecutionError("", "", "step function cannot be nil") } - operationName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() + stepName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() // Apply options to build params with defaults params := StepParams{ @@ -614,39 +614,39 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op } // Get workflow state from context - workflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState) - if !ok || workflowState == nil { - return *new(R), newStepExecutionError("", operationName, "workflow state not found in context: are you running this step within a workflow?") + wfState, ok := ctx.Value(workflowStateKey).(*workflowState) + if !ok || wfState == nil { + return *new(R), newStepExecutionError("", stepName, "workflow state not found in context: are you running this step within a workflow?") } // If within a step, just run the function directly - if workflowState.isWithinStep { + if wfState.isWithinStep { return fn(ctx, input) } // Get next step ID - operationID := workflowState.NextStepID() + stepID := wfState.NextStepID() // Check the step is cancelled, has already completed, or is called with a different name recordedOutput, err := dbos.systemDB.CheckOperationExecution(ctx, checkOperationExecutionDBInput{ - workflowID: workflowState.WorkflowID, - operationID: operationID, - functionName: runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), + workflowID: wfState.workflowID, + stepID: stepID, + stepName: runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), }) if err != nil { - return *new(R), newStepExecutionError(workflowState.WorkflowID, operationName, fmt.Sprintf("checking operation execution: %v", err)) + return *new(R), newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("checking operation execution: %v", err)) } if recordedOutput != nil { return recordedOutput.output.(R), recordedOutput.err } // Execute step with retry logic if MaxRetries > 0 - stepState := WorkflowState{ - WorkflowID: workflowState.WorkflowID, - stepCounter: workflowState.stepCounter, + stepState := workflowState{ + workflowID: wfState.workflowID, + stepCounter: wfState.stepCounter, isWithinStep: true, } - stepCtx := context.WithValue(ctx, WorkflowStateKey, &stepState) + stepCtx := context.WithValue(ctx, workflowStateKey, &stepState) stepOutput, stepError := fn(stepCtx, input) @@ -663,12 +663,12 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op delay = time.Duration(math.Min(exponentialDelay, float64(params.MaxInterval))) } - getLogger().Error("step failed, retrying", "step_name", operationName, "retry", retry, "max_retries", params.MaxRetries, "delay", delay, "error", stepError) + getLogger().Error("step failed, retrying", "step_name", stepName, "retry", retry, "max_retries", params.MaxRetries, "delay", delay, "error", stepError) // Wait before retry select { case <-ctx.Done(): - return *new(R), newStepExecutionError(workflowState.WorkflowID, operationName, fmt.Sprintf("context cancelled during retry: %v", ctx.Err())) + return *new(R), newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("context cancelled during retry: %v", ctx.Err())) case <-time.After(delay): // Continue to retry } @@ -686,7 +686,7 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op // If max retries reached, create MaxStepRetriesExceeded error if retry == params.MaxRetries { - stepError = newMaxStepRetriesExceededError(workflowState.WorkflowID, operationName, params.MaxRetries, joinedErrors) + stepError = newMaxStepRetriesExceededError(wfState.workflowID, stepName, params.MaxRetries, joinedErrors) break } } @@ -694,15 +694,15 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op // Record the final result dbInput := recordOperationResultDBInput{ - workflowID: workflowState.WorkflowID, - operationName: operationName, - operationID: operationID, - err: stepError, - output: stepOutput, + workflowID: wfState.workflowID, + stepName: stepName, + stepID: stepID, + err: stepError, + output: stepOutput, } recErr := dbos.systemDB.RecordOperationResult(ctx, dbInput) if recErr != nil { - return *new(R), newStepExecutionError(workflowState.WorkflowID, operationName, fmt.Sprintf("recording step outcome: %v", recErr)) + return *new(R), newStepExecutionError(wfState.workflowID, stepName, fmt.Sprintf("recording step outcome: %v", recErr)) } return stepOutput, stepError @@ -779,6 +779,15 @@ func GetEvent[R any](ctx context.Context, input WorkflowGetEventInput) (R, error /******* WORKFLOW MANAGEMENT *******/ /***********************************/ +// GetWorkflowID retrieves the workflow ID from the context if called within a DBOS workflow +func GetWorkflowID(ctx context.Context) (string, error) { + wfState, ok := ctx.Value(workflowStateKey).(*workflowState) + if !ok || wfState == nil { + return "", errors.New("not within a DBOS workflow context") + } + return wfState.workflowID, nil +} + func RetrieveWorkflow[R any](workflowID string) (workflowPollingHandle[R], error) { ctx := context.Background() workflowStatus, err := dbos.systemDB.ListWorkflows(ctx, listWorkflowsDBInput{ diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index e8d678ea..63ff0210 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -312,7 +312,6 @@ var stepIdempotencyCounter int func stepIdempotencyTest(ctx context.Context, input string) (string, error) { stepIdempotencyCounter++ - fmt.Println("Executing idempotency step:", stepIdempotencyCounter) return input, nil } @@ -455,24 +454,22 @@ func TestSteps(t *testing.T) { var ( childWf = WithWorkflow(func(ctx context.Context, i int) (string, error) { - workflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState) - if !ok { - return "", fmt.Errorf("workflow state not found in context") + workflowID, err := GetWorkflowID(ctx) + if err != nil { + return "", fmt.Errorf("failed to get workflow ID: %v", err) } - fmt.Println("childWf workflow state:", workflowState) - expectedCurrentID := fmt.Sprintf("%s-%d", workflowState.WorkflowID, i) - if workflowState.WorkflowID != expectedCurrentID { - return "", fmt.Errorf("expected parentWf workflow ID to be %s, got %s", expectedCurrentID, workflowState.WorkflowID) + expectedCurrentID := fmt.Sprintf("%s-%d", workflowID, i) + if workflowID != expectedCurrentID { + return "", fmt.Errorf("expected parentWf workflow ID to be %s, got %s", expectedCurrentID, workflowID) } // XXX right now the steps of a child workflow start with an incremented step ID, because the first step ID is allocated to the child workflow return RunAsStep(ctx, simpleStep, "") }) parentWf = WithWorkflow(func(ctx context.Context, i int) (string, error) { - workflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState) - if !ok { - return "", fmt.Errorf("workflow state not found in context") + workflowID, err := GetWorkflowID(ctx) + if err != nil { + return "", fmt.Errorf("failed to get workflow ID: %v", err) } - fmt.Println("parentWf workflow state:", workflowState) childHandle, err := childWf(ctx, i) if err != nil { @@ -480,14 +477,14 @@ var ( } // Check this wf ID is built correctly - expectedParentID := fmt.Sprintf("%s-%d", workflowState.WorkflowID, i) - if workflowState.WorkflowID != expectedParentID { - return "", fmt.Errorf("expected parentWf workflow ID to be %s, got %s", expectedParentID, workflowState.WorkflowID) + expectedParentID := fmt.Sprintf("%s-%d", workflowID, i) + if workflowID != expectedParentID { + return "", fmt.Errorf("expected parentWf workflow ID to be %s, got %s", expectedParentID, workflowID) } // Verify child workflow ID follows the pattern: parentID-functionID childWorkflowID := childHandle.GetWorkflowID() - expectedChildID := fmt.Sprintf("%s-%d", workflowState.WorkflowID, i) + expectedChildID := fmt.Sprintf("%s-%d", workflowID, i) if childWorkflowID != expectedChildID { return "", fmt.Errorf("expected childWf ID to be %s, got %s", expectedChildID, childWorkflowID) } @@ -495,11 +492,10 @@ var ( }) grandParentWf = WithWorkflow(func(ctx context.Context, _ string) (string, error) { for i := range 3 { - workflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState) - if !ok { - return "", fmt.Errorf("workflow state not found in context") + workflowID, err := GetWorkflowID(ctx) + if err != nil { + return "", fmt.Errorf("failed to get workflow ID: %v", err) } - fmt.Println("grandParentWf workflow state:", workflowState) childHandle, err := parentWf(ctx, i) if err != nil { @@ -507,14 +503,14 @@ var ( } // The handle should a direct handle - _, ok = childHandle.(*workflowHandle[string]) + _, ok := childHandle.(*workflowHandle[string]) if !ok { return "", fmt.Errorf("expected childHandle to be of type *workflowHandle[string], got %T", childHandle) } // Verify child workflow ID follows the pattern: parentID-functionID childWorkflowID := childHandle.GetWorkflowID() - expectedPrefix := fmt.Sprintf("%s-%d", workflowState.WorkflowID, i) + expectedPrefix := fmt.Sprintf("%s-%d", workflowID, i) if childWorkflowID != expectedPrefix { return "", fmt.Errorf("expected parentWf workflow ID to be %s, got %s", expectedPrefix, childWorkflowID) } @@ -894,7 +890,6 @@ var ( counter1Ch = make(chan time.Time, 100) _ = WithWorkflow(func(ctx context.Context, scheduledTime time.Time) (string, error) { startTime := time.Now() - // fmt.Println("scheduled time:", scheduledTime, "current time:", startTime) counter++ if counter == 10 { return "", fmt.Errorf("counter reached 100, stopping workflow") @@ -989,7 +984,6 @@ type sendWorkflowInput struct { } func sendWorkflow(ctx context.Context, input sendWorkflowInput) (string, error) { - fmt.Println("Starting send workflow with input:", input) err := Send(ctx, WorkflowSendInput{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message1"}) if err != nil { return "", err @@ -1002,7 +996,6 @@ func sendWorkflow(ctx context.Context, input sendWorkflowInput) (string, error) if err != nil { return "", err } - fmt.Println("Sending message on topic:", input.Topic, "to destination:", input.DestinationID) return "", nil } @@ -1035,7 +1028,6 @@ func receiveWorkflowCoordinated(ctx context.Context, input struct { // Do a single Recv call with timeout msg, err := Recv[string](ctx, WorkflowRecvInput{Topic: input.Topic, Timeout: 3 * time.Second}) - fmt.Println(err) if err != nil { return "", err }