Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion dbos/dbos.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,13 @@ func Launch() error {
}

// Run a round of recovery on the local executor
_, err := recoverPendingWorkflows(context.Background(), []string{_EXECUTOR_ID}) // XXX maybe use the queue runner context here to allow Shutdown to cancel it?
recoveryHandles, err := recoverPendingWorkflows(context.Background(), []string{_EXECUTOR_ID}) // XXX maybe use the queue runner context here to allow Shutdown to cancel it?
if err != nil {
return newInitializationError(fmt.Sprintf("failed to recover pending workflows during launch: %v", err))
}
if len(recoveryHandles) > 0 {
logger.Info("Recovered pending workflows", "count", len(recoveryHandles))
}

logger.Info("DBOS initialized", "app_version", _APP_VERSION, "executor_id", _EXECUTOR_ID)
return nil
Expand Down
4 changes: 3 additions & 1 deletion dbos/migrations/000001_initial_dbos_schema.down.sql
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
-- 001_initial_dbos_schema.down.sql

-- Drop trigger first
-- Drop triggers first
DROP TRIGGER IF EXISTS dbos_notifications_trigger ON dbos.notifications;
DROP TRIGGER IF EXISTS dbos_workflow_events_trigger ON dbos.workflow_events;

-- Drop function
DROP FUNCTION IF EXISTS dbos.notifications_function();
DROP FUNCTION IF EXISTS dbos.workflow_events_function();

-- Drop tables in reverse order to respect foreign key constraints
DROP TABLE IF EXISTS dbos.workflow_events;
Expand Down
17 changes: 16 additions & 1 deletion dbos/migrations/000001_initial_dbos_schema.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,19 @@ CREATE TABLE dbos.workflow_events (
PRIMARY KEY (workflow_uuid, key),
FOREIGN KEY (workflow_uuid) REFERENCES dbos.workflow_status(workflow_uuid)
ON UPDATE CASCADE ON DELETE CASCADE
);
);

-- Create events function
CREATE OR REPLACE FUNCTION dbos.workflow_events_function() RETURNS TRIGGER AS $$
DECLARE
payload text := NEW.workflow_uuid || '::' || NEW.key;
BEGIN
PERFORM pg_notify('dbos_workflow_events_channel', payload);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;

-- Create events trigger
CREATE TRIGGER dbos_workflow_events_trigger
AFTER INSERT ON dbos.workflow_events
FOR EACH ROW EXECUTE FUNCTION dbos.workflow_events_function();
257 changes: 227 additions & 30 deletions dbos/system_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ type SystemDatabase interface {
GetWorkflowSteps(ctx context.Context, workflowID string) ([]StepInfo, error)
Send(ctx context.Context, input WorkflowSendInput) error
Recv(ctx context.Context, input WorkflowRecvInput) (any, error)
SetEvent(ctx context.Context, input WorkflowSetEventInput) error
GetEvent(ctx context.Context, input WorkflowGetEventInput) (any, error)
}

type systemDatabase struct {
Expand Down Expand Up @@ -157,15 +159,11 @@ func NewSystemDatabase(databaseURL string) (SystemDatabase, error) {
return nil, fmt.Errorf("failed to parse database URL: %v", err)
}
config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) {
if n.Channel == "dbos_notifications_channel" {
if n.Channel == "dbos_notifications_channel" || n.Channel == "dbos_workflow_events_channel" {
// Check if an entry exists in the map, indexed by the payload
// If yes, send a signal to the channel so the listener can wake up
if ch, exists := notificationsMap.Load(n.Payload); exists {
select {
case ch.(chan bool) <- true: // Send a signal to wake up the listener
default:
getLogger().Warn("notification channel for payload is full, skipping", "payload", n.Payload)
}
// If yes, broadcast on the condition variable so listeners can wake up
if cond, exists := notificationsMap.Load(n.Payload); exists {
cond.(*sync.Cond).Broadcast()
}
}
}
Expand Down Expand Up @@ -971,17 +969,17 @@ func (s *systemDatabase) GetWorkflowSteps(ctx context.Context, workflowID string
/****************************************/

func (s *systemDatabase) notificationListenerLoop(ctx context.Context) {
mrr := s.notificationListenerConnection.Exec(ctx, "LISTEN dbos_notifications_channel")
mrr := s.notificationListenerConnection.Exec(ctx, "LISTEN dbos_notifications_channel; LISTEN dbos_workflow_events_channel")
results, err := mrr.ReadAll()
if err != nil {
getLogger().Error("Failed to listen on dbos_notifications_channel", "error", err)
getLogger().Error("Failed to listen on notification channels", "error", err)
return
}
mrr.Close()

for _, result := range results {
if result.Err != nil {
getLogger().Error("Error listening on dbos_notifications_channel", "error", result.Err)
getLogger().Error("Error listening on notification channels", "error", result.Err)
return
}
}
Expand Down Expand Up @@ -1040,6 +1038,7 @@ func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) erro
return err
}
if recordedResult != nil {
// when hitting this case, recordedResult will be &{<nil> <nil>}
return nil
}

Expand Down Expand Up @@ -1114,19 +1113,12 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any
topic = input.Topic
}

tx, err := s.pool.Begin(ctx)
if err != nil {
return nil, fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback(ctx)

// Check if operation was already executed
// XXX this might not need to be in the transaction
checkInput := checkOperationExecutionDBInput{
workflowID: destinationID,
operationID: stepID,
functionName: functionName,
tx: tx,
}
recordedResult, err := s.CheckOperationExecution(ctx, checkInput)
if err != nil {
Expand All @@ -1141,42 +1133,54 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any

// First check if there's already a receiver for this workflow/topic to avoid unnecessary database load
payload := fmt.Sprintf("%s::%s", destinationID, topic)
c := make(chan bool, 1) // Make it buffered to allow the notification listener to post a signal even if the receiver has not reached its select statement yet
_, loaded := s.notificationsMap.LoadOrStore(payload, c)
cond := sync.NewCond(&sync.Mutex{})
_, loaded := s.notificationsMap.LoadOrStore(payload, cond)
if loaded {
close(c)
getLogger().Error("Receive already called for workflow", "destination_id", destinationID)
return nil, newWorkflowConflictIDError(destinationID)
}
defer func() {
// Clean up the channel after we're done
// Clean up the condition variable after we're done and broadcast to wake up any waiting goroutines
// XXX We should handle panics in this function and make sure we call this. Not a problem for now as panic will crash the importing package.
cond.Broadcast()
s.notificationsMap.Delete(payload)
close(c)
}()

// Now check if there is already a message available in the database.
// If not, we'll wait for a notification and timeout
var exists bool
query := `SELECT EXISTS (SELECT 1 FROM dbos.notifications WHERE destination_uuid = $1 AND topic = $2)`
err = tx.QueryRow(ctx, query, destinationID, topic).Scan(&exists)
err = s.pool.QueryRow(ctx, query, destinationID, topic).Scan(&exists)
if err != nil {
return false, fmt.Errorf("failed to check message: %w", err)
}
if !exists {
// Listen for notifications on the channel
// Wait for notifications using condition variable with timeout pattern
// XXX should we prevent zero or negative timeouts?
getLogger().Debug("Waiting for notification on channel", "payload", payload)
getLogger().Debug("Waiting for notification on condition variable", "payload", payload)

done := make(chan struct{})
go func() {
cond.L.Lock()
defer cond.L.Unlock()
cond.Wait()
close(done)
}()

select {
case <-c:
getLogger().Debug("Received notification on channel", "payload", payload)
case <-done:
getLogger().Debug("Received notification on condition variable", "payload", payload)
case <-time.After(input.Timeout):
// If we reach the timeout, we can check if there is a message in the database, and if not return nil.
getLogger().Warn("Timeout reached for channel", "payload", payload, "timeout", input.Timeout)
getLogger().Warn("Recv() timeout reached", "payload", payload, "timeout", input.Timeout)
}
}

// Find the oldest message and delete it atomically
tx, err := s.pool.Begin(ctx)
if err != nil {
return nil, fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback(ctx)
Comment on lines +1179 to +1183
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The transaction should start for the sequence:

  • consume message
  • record Recv() output

But not before

query = `
WITH oldest_entry AS (
SELECT destination_uuid, topic, message, created_at_epoch_ms
Expand Down Expand Up @@ -1231,6 +1235,199 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any
return message, nil
}

func (s *systemDatabase) SetEvent(ctx context.Context, input WorkflowSetEventInput) error {
functionName := "DBOS.setEvent"

// Get workflow state from context
workflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState)
if !ok || workflowState == nil {
return newStepExecutionError("", functionName, "workflow state not found in context: are you running this step within a workflow?")
}

if workflowState.isWithinStep {
return newStepExecutionError(workflowState.WorkflowID, functionName, "cannot call SetEvent within a step")
}

stepID := workflowState.NextStepID()

tx, err := s.pool.Begin(ctx)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback(ctx)

// Check if operation was already executed and do nothing if so
checkInput := checkOperationExecutionDBInput{
workflowID: workflowState.WorkflowID,
operationID: stepID,
functionName: functionName,
tx: tx,
}
recordedResult, err := s.CheckOperationExecution(ctx, checkInput)
if err != nil {
return err
}
if recordedResult != nil {
// when hitting this case, recordedResult will be &{<nil> <nil>}
return nil
}

// Serialize the message. It must have been registered with encoding/gob by the user if not a basic type.
messageString, err := serialize(input.Message)
if err != nil {
return fmt.Errorf("failed to serialize message: %w", err)
}

// Insert or update the event using UPSERT
insertQuery := `INSERT INTO dbos.workflow_events (workflow_uuid, key, value)
VALUES ($1, $2, $3)
ON CONFLICT (workflow_uuid, key)
DO UPDATE SET value = EXCLUDED.value`

_, err = tx.Exec(ctx, insertQuery, workflowState.WorkflowID, input.Key, messageString)
if err != nil {
return fmt.Errorf("failed to insert/update workflow event: %w", err)
}

// Record the operation result
recordInput := recordOperationResultDBInput{
workflowID: workflowState.WorkflowID,
operationID: stepID,
operationName: functionName,
output: nil,
err: nil,
tx: tx,
}

err = s.RecordOperationResult(ctx, recordInput)
if err != nil {
return fmt.Errorf("failed to record operation result: %w", err)
}

// Commit transaction
if err := tx.Commit(ctx); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}

return nil
}

func (s *systemDatabase) GetEvent(ctx context.Context, input WorkflowGetEventInput) (any, error) {
functionName := "DBOS.getEvent"

// Get workflow state from context (optional for GetEvent as we can get an event from outside a workflow)
workflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState)
var stepID int
var isInWorkflow bool

if ok && workflowState != nil {
isInWorkflow = true
if workflowState.isWithinStep {
return nil, newStepExecutionError(workflowState.WorkflowID, functionName, "cannot call GetEvent within a step")
}
stepID = workflowState.NextStepID()

// Check if operation was already executed (only if in workflow)
checkInput := checkOperationExecutionDBInput{
workflowID: workflowState.WorkflowID,
operationID: stepID,
functionName: functionName,
}
recordedResult, err := s.CheckOperationExecution(ctx, checkInput)
if err != nil {
return nil, err
}
if recordedResult != nil {
return recordedResult.output, recordedResult.err
}
}

// Create notification payload and condition variable
payload := fmt.Sprintf("%s::%s", input.TargetWorkflowID, input.Key)
cond := sync.NewCond(&sync.Mutex{})
existingCond, loaded := s.notificationsMap.LoadOrStore(payload, cond)
if loaded {
// Reuse the existing condition variable
cond = existingCond.(*sync.Cond)
}

// Defer broadcast to ensure any waiting goroutines eventually unlock
defer func() {
cond.Broadcast()
// Clean up the condition variable after we're done
s.notificationsMap.Delete(payload)
}()

// Check if the event already exists in the database
query := `SELECT value FROM dbos.workflow_events WHERE workflow_uuid = $1 AND key = $2`
var value any
var valueString *string

row := s.pool.QueryRow(ctx, query, input.TargetWorkflowID, input.Key)
err := row.Scan(&valueString)
if err != nil && err != pgx.ErrNoRows {
return nil, fmt.Errorf("failed to query workflow event: %w", err)
}

if err == pgx.ErrNoRows || valueString == nil { // XXX valueString should never be `nil`
// Wait for notification with timeout using condition variable
done := make(chan struct{})
go func() {
cond.L.Lock()
defer cond.L.Unlock()
cond.Wait()
close(done)
}()

select {
case <-done:
// Received notification
case <-time.After(input.Timeout):
// Timeout reached
getLogger().Warn("GetEvent() timeout reached", "target_workflow_id", input.TargetWorkflowID, "key", input.Key, "timeout", input.Timeout)
case <-ctx.Done():
return nil, fmt.Errorf("context cancelled while waiting for event: %w", ctx.Err())
}

// Query the database again after waiting
row = s.pool.QueryRow(ctx, query, input.TargetWorkflowID, input.Key)
err = row.Scan(&valueString)
if err != nil {
if err == pgx.ErrNoRows {
value = nil // Event still doesn't exist
} else {
return nil, fmt.Errorf("failed to query workflow event after wait: %w", err)
}
}
}

// Deserialize the value if it exists
if valueString != nil {
value, err = deserialize(valueString)
if err != nil {
return nil, fmt.Errorf("failed to deserialize event value: %w", err)
}
}

// Record the operation result if this is called within a workflow
if isInWorkflow {
recordInput := recordOperationResultDBInput{
workflowID: workflowState.WorkflowID,
operationID: stepID,
operationName: functionName,
output: value,
err: nil,
}

err = s.RecordOperationResult(ctx, recordInput)
if err != nil {
return nil, fmt.Errorf("failed to record operation result: %w", err)
}
}

return value, nil
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I trust recording the step outcome + reading the event do not need being done transactionally.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, they don't, as it doesn't consume the event

/*******************************/
/******* QUEUES ********/
/*******************************/
Expand Down
Loading