Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion dbos/dbos.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,13 @@ func Launch() error {
}

// Run a round of recovery on the local executor
_, err := recoverPendingWorkflows(context.Background(), []string{_EXECUTOR_ID}) // XXX maybe use the queue runner context here to allow Shutdown to cancel it?
recoveryHandles, err := recoverPendingWorkflows(context.Background(), []string{_EXECUTOR_ID}) // XXX maybe use the queue runner context here to allow Shutdown to cancel it?
if err != nil {
return newInitializationError(fmt.Sprintf("failed to recover pending workflows during launch: %v", err))
}
if len(recoveryHandles) > 0 {
logger.Info("Recovered pending workflows", "count", len(recoveryHandles))
}

logger.Info("DBOS initialized", "app_version", _APP_VERSION, "executor_id", _EXECUTOR_ID)
return nil
Expand Down
4 changes: 3 additions & 1 deletion dbos/migrations/000001_initial_dbos_schema.down.sql
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
-- 001_initial_dbos_schema.down.sql

-- Drop trigger first
-- Drop triggers first
DROP TRIGGER IF EXISTS dbos_notifications_trigger ON dbos.notifications;
DROP TRIGGER IF EXISTS dbos_workflow_events_trigger ON dbos.workflow_events;

-- Drop function
DROP FUNCTION IF EXISTS dbos.notifications_function();
DROP FUNCTION IF EXISTS dbos.workflow_events_function();

-- Drop tables in reverse order to respect foreign key constraints
DROP TABLE IF EXISTS dbos.workflow_events;
Expand Down
17 changes: 16 additions & 1 deletion dbos/migrations/000001_initial_dbos_schema.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,19 @@ CREATE TABLE dbos.workflow_events (
PRIMARY KEY (workflow_uuid, key),
FOREIGN KEY (workflow_uuid) REFERENCES dbos.workflow_status(workflow_uuid)
ON UPDATE CASCADE ON DELETE CASCADE
);
);

-- Create events function
-- Trigger function for rows of dbos.workflow_events. It builds the text payload
-- '<workflow_uuid>::<key>' from the NEW row, publishes it with pg_notify on the
-- 'dbos_workflow_events_channel' channel, and returns NEW.
CREATE OR REPLACE FUNCTION dbos.workflow_events_function() RETURNS TRIGGER AS $$
DECLARE
payload text := NEW.workflow_uuid || '::' || NEW.key;
BEGIN
PERFORM pg_notify('dbos_workflow_events_channel', payload);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;

-- Create events trigger
-- Runs dbos.workflow_events_function() once for every row inserted into
-- dbos.workflow_events (AFTER INSERT, FOR EACH ROW).
-- NOTE(review): only INSERT is covered here; rows changed through an upsert's
-- DO UPDATE path presumably do not emit a notification -- confirm this is intended.
CREATE TRIGGER dbos_workflow_events_trigger
AFTER INSERT ON dbos.workflow_events
FOR EACH ROW EXECUTE FUNCTION dbos.workflow_events_function();
259 changes: 229 additions & 30 deletions dbos/system_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ type SystemDatabase interface {
GetWorkflowSteps(ctx context.Context, workflowID string) ([]StepInfo, error)
Send(ctx context.Context, input WorkflowSendInput) error
Recv(ctx context.Context, input WorkflowRecvInput) (any, error)
SetEvent(ctx context.Context, input WorkflowSetEventInput) error
GetEvent(ctx context.Context, input WorkflowGetEventInput) (any, error)
}

type systemDatabase struct {
Expand Down Expand Up @@ -157,15 +159,11 @@ func NewSystemDatabase(databaseURL string) (SystemDatabase, error) {
return nil, fmt.Errorf("failed to parse database URL: %v", err)
}
config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) {
if n.Channel == "dbos_notifications_channel" {
if n.Channel == "dbos_notifications_channel" || n.Channel == "dbos_workflow_events_channel" {
// Check if an entry exists in the map, indexed by the payload
// If yes, send a signal to the channel so the listener can wake up
if ch, exists := notificationsMap.Load(n.Payload); exists {
select {
case ch.(chan bool) <- true: // Send a signal to wake up the listener
default:
getLogger().Warn("notification channel for payload is full, skipping", "payload", n.Payload)
}
// If yes, broadcast on the condition variable so listeners can wake up
if cond, exists := notificationsMap.Load(n.Payload); exists {
cond.(*sync.Cond).Broadcast()
}
}
}
Expand Down Expand Up @@ -971,17 +969,17 @@ func (s *systemDatabase) GetWorkflowSteps(ctx context.Context, workflowID string
/****************************************/

func (s *systemDatabase) notificationListenerLoop(ctx context.Context) {
mrr := s.notificationListenerConnection.Exec(ctx, "LISTEN dbos_notifications_channel")
mrr := s.notificationListenerConnection.Exec(ctx, "LISTEN dbos_notifications_channel; LISTEN dbos_workflow_events_channel")
results, err := mrr.ReadAll()
if err != nil {
getLogger().Error("Failed to listen on dbos_notifications_channel", "error", err)
getLogger().Error("Failed to listen on notification channels", "error", err)
return
}
mrr.Close()

for _, result := range results {
if result.Err != nil {
getLogger().Error("Error listening on dbos_notifications_channel", "error", result.Err)
getLogger().Error("Error listening on notification channels", "error", result.Err)
return
}
}
Expand Down Expand Up @@ -1040,6 +1038,7 @@ func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) erro
return err
}
if recordedResult != nil {
// when hitting this case, recordedResult will be &{<nil> <nil>}
return nil
}

Expand Down Expand Up @@ -1114,19 +1113,12 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any
topic = input.Topic
}

tx, err := s.pool.Begin(ctx)
if err != nil {
return nil, fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback(ctx)

// Check if operation was already executed
// XXX this might not need to be in the transaction
checkInput := checkOperationExecutionDBInput{
workflowID: destinationID,
operationID: stepID,
functionName: functionName,
tx: tx,
}
recordedResult, err := s.CheckOperationExecution(ctx, checkInput)
if err != nil {
Expand All @@ -1141,42 +1133,54 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any

// First check if there's already a receiver for this workflow/topic to avoid unnecessary database load
payload := fmt.Sprintf("%s::%s", destinationID, topic)
c := make(chan bool, 1) // Make it buffered to allow the notification listener to post a signal even if the receiver has not reached its select statement yet
_, loaded := s.notificationsMap.LoadOrStore(payload, c)
cond := sync.NewCond(&sync.Mutex{})
_, loaded := s.notificationsMap.LoadOrStore(payload, cond)
if loaded {
close(c)
getLogger().Error("Receive already called for workflow", "destination_id", destinationID)
return nil, newWorkflowConflictIDError(destinationID)
}
defer func() {
// Clean up the channel after we're done
// Clean up the condition variable after we're done and broadcast to wake up any waiting goroutines
// XXX We should handle panics in this function and make sure we call this. Not a problem for now as panic will crash the importing package.
cond.Broadcast()
s.notificationsMap.Delete(payload)
close(c)
}()

// Now check if there is already a message available in the database.
// If not, we'll wait for a notification and timeout
var exists bool
query := `SELECT EXISTS (SELECT 1 FROM dbos.notifications WHERE destination_uuid = $1 AND topic = $2)`
err = tx.QueryRow(ctx, query, destinationID, topic).Scan(&exists)
err = s.pool.QueryRow(ctx, query, destinationID, topic).Scan(&exists)
if err != nil {
return false, fmt.Errorf("failed to check message: %w", err)
}
if !exists {
// Listen for notifications on the channel
// Wait for notifications using condition variable with timeout pattern
// XXX should we prevent zero or negative timeouts?
getLogger().Debug("Waiting for notification on channel", "payload", payload)
getLogger().Debug("Waiting for notification on condition variable", "payload", payload)

done := make(chan struct{})
go func() {
cond.L.Lock()
defer cond.L.Unlock()
cond.Wait()
close(done)
}()

select {
case <-c:
getLogger().Debug("Received notification on channel", "payload", payload)
case <-done:
getLogger().Debug("Received notification on condition variable", "payload", payload)
case <-time.After(input.Timeout):
// If we reach the timeout, we can check if there is a message in the database, and if not return nil.
getLogger().Warn("Timeout reached for channel", "payload", payload, "timeout", input.Timeout)
getLogger().Warn("Recv() timeout reached", "payload", payload, "timeout", input.Timeout)
}
}

// Find the oldest message and delete it atomically
tx, err := s.pool.Begin(ctx)
if err != nil {
return nil, fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback(ctx)
Comment on lines +1179 to +1183
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The transaction should start for the sequence:

  • consume message
  • record Recv() output

But not before

query = `
WITH oldest_entry AS (
SELECT destination_uuid, topic, message, created_at_epoch_ms
Expand Down Expand Up @@ -1231,6 +1235,201 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any
return message, nil
}

// SetEvent stores input.Message under input.Key for the workflow found in ctx.
//
// It must be called from a workflow and outside of a step; otherwise a step
// execution error is returned. The upsert into dbos.workflow_events and the
// recording of this operation's result run in a single transaction. If the
// operation was already recorded for this step ID, nothing is written and nil
// is returned.
func (s *systemDatabase) SetEvent(ctx context.Context, input WorkflowSetEventInput) error {
	opName := "DBOS.setEvent"

	// The workflow state carries the workflow ID and the step counter.
	wfState, found := ctx.Value(WorkflowStateKey).(*WorkflowState)
	switch {
	case !found || wfState == nil:
		return newStepExecutionError("", opName, "workflow state not found in context: are you running this step within a workflow?")
	case wfState.isWithinStep:
		return newStepExecutionError(wfState.WorkflowID, opName, "cannot call SetEvent within a step")
	}

	stepID := wfState.NextStepID()

	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Rollback runs on every exit path, including after a successful Commit.
	defer tx.Rollback(ctx)

	// A non-nil recorded result means this step already ran: do nothing.
	previous, err := s.CheckOperationExecution(ctx, checkOperationExecutionDBInput{
		workflowID:   wfState.WorkflowID,
		operationID:  stepID,
		functionName: opName,
		tx:           tx,
	})
	if err != nil {
		return err
	}
	if previous != nil {
		return nil
	}

	// Non-basic message types must have been registered with encoding/gob by the user.
	encoded, err := serialize(input.Message)
	if err != nil {
		return fmt.Errorf("failed to serialize message: %w", err)
	}

	// Upsert: create the event or overwrite its value.
	upsert := `INSERT INTO dbos.workflow_events (workflow_uuid, key, value)
VALUES ($1, $2, $3)
ON CONFLICT (workflow_uuid, key)
DO UPDATE SET value = EXCLUDED.value`
	if _, err = tx.Exec(ctx, upsert, wfState.WorkflowID, input.Key, encoded); err != nil {
		return fmt.Errorf("failed to insert/update workflow event: %w", err)
	}

	// Record the (empty) outcome of this step in the same transaction.
	err = s.RecordOperationResult(ctx, recordOperationResultDBInput{
		workflowID:    wfState.WorkflowID,
		operationID:   stepID,
		operationName: opName,
		output:        nil,
		err:           nil,
		tx:            tx,
	})
	if err != nil {
		return fmt.Errorf("failed to record operation result: %w", err)
	}

	if err = tx.Commit(ctx); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	return nil
}

// GetEvent returns the value stored under input.Key for the workflow
// input.TargetWorkflowID, waiting up to input.Timeout for it to appear.
//
// It may be called from inside or outside a workflow. Inside a workflow it is
// not allowed within a step, a previously recorded result for this step ID is
// returned as is, and the outcome is recorded before returning. If the event
// still does not exist after the wait, the returned value is nil with a nil
// error. An error is returned when ctx is cancelled during the wait or when a
// query, deserialization, or result recording fails.
func (s *systemDatabase) GetEvent(ctx context.Context, input WorkflowGetEventInput) (any, error) {
	functionName := "DBOS.getEvent"

	// Get workflow state from context (optional for GetEvent as we can get an event from outside a workflow)
	workflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState)
	var stepID int
	var isInWorkflow bool

	if ok && workflowState != nil {
		isInWorkflow = true
		if workflowState.isWithinStep {
			return nil, newStepExecutionError(workflowState.WorkflowID, functionName, "cannot call GetEvent within a step")
		}
		stepID = workflowState.NextStepID()

		// Check if operation was already executed (only if in workflow)
		checkInput := checkOperationExecutionDBInput{
			workflowID:   workflowState.WorkflowID,
			operationID:  stepID,
			functionName: functionName,
		}
		recordedResult, err := s.CheckOperationExecution(ctx, checkInput)
		if err != nil {
			return nil, err
		}
		if recordedResult != nil {
			return recordedResult.output, recordedResult.err
		}
	}

	// Create notification payload and condition variable. Concurrent callers
	// waiting on the same payload share the condition variable already stored
	// in the map.
	payload := fmt.Sprintf("%s::%s", input.TargetWorkflowID, input.Key)
	cond := sync.NewCond(&sync.Mutex{})
	existingCond, loaded := s.notificationsMap.LoadOrStore(payload, cond)
	if loaded {
		// Reuse the existing condition variable
		cond = existingCond.(*sync.Cond)
	}

	// finished tells this call's waiter goroutine that the call is over. It is
	// only read and written while holding cond.L, so the waiter either sees it
	// set and never parks, or is already parked when the Broadcast below runs.
	// Without it, a Broadcast issued before the waiter reached cond.Wait() was
	// lost and the goroutine stayed blocked forever.
	finished := false
	defer func() {
		cond.L.Lock()
		finished = true
		cond.L.Unlock()
		// NOTE(review): when the condition variable is shared, this Broadcast
		// also wakes other callers waiting on the same payload, which then
		// re-query before their own timeout elapsed -- confirm this is acceptable.
		cond.Broadcast()
		// Clean up the condition variable after we're done, if we created it
		if !loaded {
			s.notificationsMap.Delete(payload)
		}
	}()

	// Check if the event already exists in the database
	query := `SELECT value FROM dbos.workflow_events WHERE workflow_uuid = $1 AND key = $2`
	var value any
	var valueString *string

	row := s.pool.QueryRow(ctx, query, input.TargetWorkflowID, input.Key)
	err := row.Scan(&valueString)
	if err != nil && err != pgx.ErrNoRows {
		return nil, fmt.Errorf("failed to query workflow event: %w", err)
	}

	if err == pgx.ErrNoRows || valueString == nil { // XXX valueString should never be `nil`
		// Wait for a wake-up on the condition variable in a helper goroutine so
		// that the wait can be raced against the timeout and ctx cancellation.
		// NOTE(review): a wake-up issued between the query above and the
		// moment the waiter parks is presumably missed; the event is then only
		// picked up by the re-query after the timeout.
		done := make(chan struct{})
		go func() {
			cond.L.Lock()
			defer cond.L.Unlock()
			if !finished {
				cond.Wait()
			}
			close(done)
		}()

		// Use an explicit timer so it is released on every return path.
		timer := time.NewTimer(input.Timeout)
		defer timer.Stop()

		select {
		case <-done:
			// Received notification
		case <-timer.C:
			// Timeout reached
			getLogger().Warn("GetEvent() timeout reached", "target_workflow_id", input.TargetWorkflowID, "key", input.Key, "timeout", input.Timeout)
		case <-ctx.Done():
			return nil, fmt.Errorf("context cancelled while waiting for event: %w", ctx.Err())
		}

		// Query the database again after waiting
		row = s.pool.QueryRow(ctx, query, input.TargetWorkflowID, input.Key)
		err = row.Scan(&valueString)
		if err != nil {
			if err == pgx.ErrNoRows {
				value = nil // Event still doesn't exist
			} else {
				return nil, fmt.Errorf("failed to query workflow event after wait: %w", err)
			}
		}
	}

	// Deserialize the value if it exists
	if valueString != nil {
		value, err = deserialize(valueString)
		if err != nil {
			return nil, fmt.Errorf("failed to deserialize event value: %w", err)
		}
	}

	// Record the operation result if this is called within a workflow. Reading
	// does not consume the event, so this is not done in a transaction with
	// the queries above.
	if isInWorkflow {
		recordInput := recordOperationResultDBInput{
			workflowID:    workflowState.WorkflowID,
			operationID:   stepID,
			operationName: functionName,
			output:        value,
			err:           nil,
		}

		err = s.RecordOperationResult(ctx, recordInput)
		if err != nil {
			return nil, fmt.Errorf("failed to record operation result: %w", err)
		}
	}

	return value, nil
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I trust that recording the step outcome and reading the event do not need to be done in the same transaction.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, they don't, as it doesn't consume the event

/*******************************/
/******* QUEUES ********/
/*******************************/
Expand Down
Loading
Loading