Skip to content

Commit 1d134da

Browse files
authored
Send get events (#38)
- Implement set/get events - Modify the notification system to use condition variables: while send/recv work fine with 1 producer: 1 consumer, set/get events allow for concurrent consumers. - Add a log entry when `dbos.Launch()` recover workflows. - Fix transaction usage in `Recv()`
1 parent 95f06a8 commit 1d134da

File tree

6 files changed

+753
-75
lines changed

6 files changed

+753
-75
lines changed

dbos/dbos.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,13 @@ func Launch() error {
211211
}
212212

213213
// Run a round of recovery on the local executor
214-
_, err := recoverPendingWorkflows(context.Background(), []string{_EXECUTOR_ID}) // XXX maybe use the queue runner context here to allow Shutdown to cancel it?
214+
recoveryHandles, err := recoverPendingWorkflows(context.Background(), []string{_EXECUTOR_ID}) // XXX maybe use the queue runner context here to allow Shutdown to cancel it?
215215
if err != nil {
216216
return newInitializationError(fmt.Sprintf("failed to recover pending workflows during launch: %v", err))
217217
}
218+
if len(recoveryHandles) > 0 {
219+
logger.Info("Recovered pending workflows", "count", len(recoveryHandles))
220+
}
218221

219222
logger.Info("DBOS initialized", "app_version", _APP_VERSION, "executor_id", _EXECUTOR_ID)
220223
return nil

dbos/migrations/000001_initial_dbos_schema.down.sql

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
-- 001_initial_dbos_schema.down.sql
22

3-
-- Drop trigger first
3+
-- Drop triggers first
44
DROP TRIGGER IF EXISTS dbos_notifications_trigger ON dbos.notifications;
5+
DROP TRIGGER IF EXISTS dbos_workflow_events_trigger ON dbos.workflow_events;
56

67
-- Drop function
78
DROP FUNCTION IF EXISTS dbos.notifications_function();
9+
DROP FUNCTION IF EXISTS dbos.workflow_events_function();
810

911
-- Drop tables in reverse order to respect foreign key constraints
1012
DROP TABLE IF EXISTS dbos.workflow_events;

dbos/migrations/000001_initial_dbos_schema.up.sql

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,4 +91,19 @@ CREATE TABLE dbos.workflow_events (
9191
PRIMARY KEY (workflow_uuid, key),
9292
FOREIGN KEY (workflow_uuid) REFERENCES dbos.workflow_status(workflow_uuid)
9393
ON UPDATE CASCADE ON DELETE CASCADE
94-
);
94+
);
95+
96+
-- Create events function
97+
CREATE OR REPLACE FUNCTION dbos.workflow_events_function() RETURNS TRIGGER AS $$
98+
DECLARE
99+
payload text := NEW.workflow_uuid || '::' || NEW.key;
100+
BEGIN
101+
PERFORM pg_notify('dbos_workflow_events_channel', payload);
102+
RETURN NEW;
103+
END;
104+
$$ LANGUAGE plpgsql;
105+
106+
-- Create events trigger
107+
CREATE TRIGGER dbos_workflow_events_trigger
108+
AFTER INSERT ON dbos.workflow_events
109+
FOR EACH ROW EXECUTE FUNCTION dbos.workflow_events_function();

dbos/system_database.go

Lines changed: 229 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ type SystemDatabase interface {
3939
GetWorkflowSteps(ctx context.Context, workflowID string) ([]StepInfo, error)
4040
Send(ctx context.Context, input WorkflowSendInput) error
4141
Recv(ctx context.Context, input WorkflowRecvInput) (any, error)
42+
SetEvent(ctx context.Context, input WorkflowSetEventInput) error
43+
GetEvent(ctx context.Context, input WorkflowGetEventInput) (any, error)
4244
}
4345

4446
type systemDatabase struct {
@@ -157,15 +159,11 @@ func NewSystemDatabase(databaseURL string) (SystemDatabase, error) {
157159
return nil, fmt.Errorf("failed to parse database URL: %v", err)
158160
}
159161
config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) {
160-
if n.Channel == "dbos_notifications_channel" {
162+
if n.Channel == "dbos_notifications_channel" || n.Channel == "dbos_workflow_events_channel" {
161163
// Check if an entry exists in the map, indexed by the payload
162-
// If yes, send a signal to the channel so the listener can wake up
163-
if ch, exists := notificationsMap.Load(n.Payload); exists {
164-
select {
165-
case ch.(chan bool) <- true: // Send a signal to wake up the listener
166-
default:
167-
getLogger().Warn("notification channel for payload is full, skipping", "payload", n.Payload)
168-
}
164+
// If yes, broadcast on the condition variable so listeners can wake up
165+
if cond, exists := notificationsMap.Load(n.Payload); exists {
166+
cond.(*sync.Cond).Broadcast()
169167
}
170168
}
171169
}
@@ -971,17 +969,17 @@ func (s *systemDatabase) GetWorkflowSteps(ctx context.Context, workflowID string
971969
/****************************************/
972970

973971
func (s *systemDatabase) notificationListenerLoop(ctx context.Context) {
974-
mrr := s.notificationListenerConnection.Exec(ctx, "LISTEN dbos_notifications_channel")
972+
mrr := s.notificationListenerConnection.Exec(ctx, "LISTEN dbos_notifications_channel; LISTEN dbos_workflow_events_channel")
975973
results, err := mrr.ReadAll()
976974
if err != nil {
977-
getLogger().Error("Failed to listen on dbos_notifications_channel", "error", err)
975+
getLogger().Error("Failed to listen on notification channels", "error", err)
978976
return
979977
}
980978
mrr.Close()
981979

982980
for _, result := range results {
983981
if result.Err != nil {
984-
getLogger().Error("Error listening on dbos_notifications_channel", "error", result.Err)
982+
getLogger().Error("Error listening on notification channels", "error", result.Err)
985983
return
986984
}
987985
}
@@ -1040,6 +1038,7 @@ func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) erro
10401038
return err
10411039
}
10421040
if recordedResult != nil {
1041+
// when hitting this case, recordedResult will be &{<nil> <nil>}
10431042
return nil
10441043
}
10451044

@@ -1114,19 +1113,12 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any
11141113
topic = input.Topic
11151114
}
11161115

1117-
tx, err := s.pool.Begin(ctx)
1118-
if err != nil {
1119-
return nil, fmt.Errorf("failed to begin transaction: %w", err)
1120-
}
1121-
defer tx.Rollback(ctx)
1122-
11231116
// Check if operation was already executed
11241117
// XXX this might not need to be in the transaction
11251118
checkInput := checkOperationExecutionDBInput{
11261119
workflowID: destinationID,
11271120
operationID: stepID,
11281121
functionName: functionName,
1129-
tx: tx,
11301122
}
11311123
recordedResult, err := s.CheckOperationExecution(ctx, checkInput)
11321124
if err != nil {
@@ -1141,42 +1133,54 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any
11411133

11421134
// First check if there's already a receiver for this workflow/topic to avoid unnecessary database load
11431135
payload := fmt.Sprintf("%s::%s", destinationID, topic)
1144-
c := make(chan bool, 1) // Make it buffered to allow the notification listener to post a signal even if the receiver has not reached its select statement yet
1145-
_, loaded := s.notificationsMap.LoadOrStore(payload, c)
1136+
cond := sync.NewCond(&sync.Mutex{})
1137+
_, loaded := s.notificationsMap.LoadOrStore(payload, cond)
11461138
if loaded {
1147-
close(c)
11481139
getLogger().Error("Receive already called for workflow", "destination_id", destinationID)
11491140
return nil, newWorkflowConflictIDError(destinationID)
11501141
}
11511142
defer func() {
1152-
// Clean up the channel after we're done
1143+
// Clean up the condition variable after we're done and broadcast to wake up any waiting goroutines
11531144
// XXX We should handle panics in this function and make sure we call this. Not a problem for now as panic will crash the importing package.
1145+
cond.Broadcast()
11541146
s.notificationsMap.Delete(payload)
1155-
close(c)
11561147
}()
11571148

11581149
// Now check if there is already a message available in the database.
11591150
// If not, we'll wait for a notification and timeout
11601151
var exists bool
11611152
query := `SELECT EXISTS (SELECT 1 FROM dbos.notifications WHERE destination_uuid = $1 AND topic = $2)`
1162-
err = tx.QueryRow(ctx, query, destinationID, topic).Scan(&exists)
1153+
err = s.pool.QueryRow(ctx, query, destinationID, topic).Scan(&exists)
11631154
if err != nil {
11641155
return false, fmt.Errorf("failed to check message: %w", err)
11651156
}
11661157
if !exists {
1167-
// Listen for notifications on the channel
1158+
// Wait for notifications using condition variable with timeout pattern
11681159
// XXX should we prevent zero or negative timeouts?
1169-
getLogger().Debug("Waiting for notification on channel", "payload", payload)
1160+
getLogger().Debug("Waiting for notification on condition variable", "payload", payload)
1161+
1162+
done := make(chan struct{})
1163+
go func() {
1164+
cond.L.Lock()
1165+
defer cond.L.Unlock()
1166+
cond.Wait()
1167+
close(done)
1168+
}()
1169+
11701170
select {
1171-
case <-c:
1172-
getLogger().Debug("Received notification on channel", "payload", payload)
1171+
case <-done:
1172+
getLogger().Debug("Received notification on condition variable", "payload", payload)
11731173
case <-time.After(input.Timeout):
1174-
// If we reach the timeout, we can check if there is a message in the database, and if not return nil.
1175-
getLogger().Warn("Timeout reached for channel", "payload", payload, "timeout", input.Timeout)
1174+
getLogger().Warn("Recv() timeout reached", "payload", payload, "timeout", input.Timeout)
11761175
}
11771176
}
11781177

11791178
// Find the oldest message and delete it atomically
1179+
tx, err := s.pool.Begin(ctx)
1180+
if err != nil {
1181+
return nil, fmt.Errorf("failed to begin transaction: %w", err)
1182+
}
1183+
defer tx.Rollback(ctx)
11801184
query = `
11811185
WITH oldest_entry AS (
11821186
SELECT destination_uuid, topic, message, created_at_epoch_ms
@@ -1231,6 +1235,201 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any
12311235
return message, nil
12321236
}
12331237

1238+
func (s *systemDatabase) SetEvent(ctx context.Context, input WorkflowSetEventInput) error {
1239+
functionName := "DBOS.setEvent"
1240+
1241+
// Get workflow state from context
1242+
workflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState)
1243+
if !ok || workflowState == nil {
1244+
return newStepExecutionError("", functionName, "workflow state not found in context: are you running this step within a workflow?")
1245+
}
1246+
1247+
if workflowState.isWithinStep {
1248+
return newStepExecutionError(workflowState.WorkflowID, functionName, "cannot call SetEvent within a step")
1249+
}
1250+
1251+
stepID := workflowState.NextStepID()
1252+
1253+
tx, err := s.pool.Begin(ctx)
1254+
if err != nil {
1255+
return fmt.Errorf("failed to begin transaction: %w", err)
1256+
}
1257+
defer tx.Rollback(ctx)
1258+
1259+
// Check if operation was already executed and do nothing if so
1260+
checkInput := checkOperationExecutionDBInput{
1261+
workflowID: workflowState.WorkflowID,
1262+
operationID: stepID,
1263+
functionName: functionName,
1264+
tx: tx,
1265+
}
1266+
recordedResult, err := s.CheckOperationExecution(ctx, checkInput)
1267+
if err != nil {
1268+
return err
1269+
}
1270+
if recordedResult != nil {
1271+
// when hitting this case, recordedResult will be &{<nil> <nil>}
1272+
return nil
1273+
}
1274+
1275+
// Serialize the message. It must have been registered with encoding/gob by the user if not a basic type.
1276+
messageString, err := serialize(input.Message)
1277+
if err != nil {
1278+
return fmt.Errorf("failed to serialize message: %w", err)
1279+
}
1280+
1281+
// Insert or update the event using UPSERT
1282+
insertQuery := `INSERT INTO dbos.workflow_events (workflow_uuid, key, value)
1283+
VALUES ($1, $2, $3)
1284+
ON CONFLICT (workflow_uuid, key)
1285+
DO UPDATE SET value = EXCLUDED.value`
1286+
1287+
_, err = tx.Exec(ctx, insertQuery, workflowState.WorkflowID, input.Key, messageString)
1288+
if err != nil {
1289+
return fmt.Errorf("failed to insert/update workflow event: %w", err)
1290+
}
1291+
1292+
// Record the operation result
1293+
recordInput := recordOperationResultDBInput{
1294+
workflowID: workflowState.WorkflowID,
1295+
operationID: stepID,
1296+
operationName: functionName,
1297+
output: nil,
1298+
err: nil,
1299+
tx: tx,
1300+
}
1301+
1302+
err = s.RecordOperationResult(ctx, recordInput)
1303+
if err != nil {
1304+
return fmt.Errorf("failed to record operation result: %w", err)
1305+
}
1306+
1307+
// Commit transaction
1308+
if err := tx.Commit(ctx); err != nil {
1309+
return fmt.Errorf("failed to commit transaction: %w", err)
1310+
}
1311+
1312+
return nil
1313+
}
1314+
1315+
func (s *systemDatabase) GetEvent(ctx context.Context, input WorkflowGetEventInput) (any, error) {
1316+
functionName := "DBOS.getEvent"
1317+
1318+
// Get workflow state from context (optional for GetEvent as we can get an event from outside a workflow)
1319+
workflowState, ok := ctx.Value(WorkflowStateKey).(*WorkflowState)
1320+
var stepID int
1321+
var isInWorkflow bool
1322+
1323+
if ok && workflowState != nil {
1324+
isInWorkflow = true
1325+
if workflowState.isWithinStep {
1326+
return nil, newStepExecutionError(workflowState.WorkflowID, functionName, "cannot call GetEvent within a step")
1327+
}
1328+
stepID = workflowState.NextStepID()
1329+
1330+
// Check if operation was already executed (only if in workflow)
1331+
checkInput := checkOperationExecutionDBInput{
1332+
workflowID: workflowState.WorkflowID,
1333+
operationID: stepID,
1334+
functionName: functionName,
1335+
}
1336+
recordedResult, err := s.CheckOperationExecution(ctx, checkInput)
1337+
if err != nil {
1338+
return nil, err
1339+
}
1340+
if recordedResult != nil {
1341+
return recordedResult.output, recordedResult.err
1342+
}
1343+
}
1344+
1345+
// Create notification payload and condition variable
1346+
payload := fmt.Sprintf("%s::%s", input.TargetWorkflowID, input.Key)
1347+
cond := sync.NewCond(&sync.Mutex{})
1348+
existingCond, loaded := s.notificationsMap.LoadOrStore(payload, cond)
1349+
if loaded {
1350+
// Reuse the existing condition variable
1351+
cond = existingCond.(*sync.Cond)
1352+
}
1353+
1354+
// Defer broadcast to ensure any waiting goroutines eventually unlock
1355+
defer func() {
1356+
cond.Broadcast()
1357+
// Clean up the condition variable after we're done, if we created it
1358+
if !loaded {
1359+
s.notificationsMap.Delete(payload)
1360+
}
1361+
}()
1362+
1363+
// Check if the event already exists in the database
1364+
query := `SELECT value FROM dbos.workflow_events WHERE workflow_uuid = $1 AND key = $2`
1365+
var value any
1366+
var valueString *string
1367+
1368+
row := s.pool.QueryRow(ctx, query, input.TargetWorkflowID, input.Key)
1369+
err := row.Scan(&valueString)
1370+
if err != nil && err != pgx.ErrNoRows {
1371+
return nil, fmt.Errorf("failed to query workflow event: %w", err)
1372+
}
1373+
1374+
if err == pgx.ErrNoRows || valueString == nil { // XXX valueString should never be `nil`
1375+
// Wait for notification with timeout using condition variable
1376+
done := make(chan struct{})
1377+
go func() {
1378+
cond.L.Lock()
1379+
defer cond.L.Unlock()
1380+
cond.Wait()
1381+
close(done)
1382+
}()
1383+
1384+
select {
1385+
case <-done:
1386+
// Received notification
1387+
case <-time.After(input.Timeout):
1388+
// Timeout reached
1389+
getLogger().Warn("GetEvent() timeout reached", "target_workflow_id", input.TargetWorkflowID, "key", input.Key, "timeout", input.Timeout)
1390+
case <-ctx.Done():
1391+
return nil, fmt.Errorf("context cancelled while waiting for event: %w", ctx.Err())
1392+
}
1393+
1394+
// Query the database again after waiting
1395+
row = s.pool.QueryRow(ctx, query, input.TargetWorkflowID, input.Key)
1396+
err = row.Scan(&valueString)
1397+
if err != nil {
1398+
if err == pgx.ErrNoRows {
1399+
value = nil // Event still doesn't exist
1400+
} else {
1401+
return nil, fmt.Errorf("failed to query workflow event after wait: %w", err)
1402+
}
1403+
}
1404+
}
1405+
1406+
// Deserialize the value if it exists
1407+
if valueString != nil {
1408+
value, err = deserialize(valueString)
1409+
if err != nil {
1410+
return nil, fmt.Errorf("failed to deserialize event value: %w", err)
1411+
}
1412+
}
1413+
1414+
// Record the operation result if this is called within a workflow
1415+
if isInWorkflow {
1416+
recordInput := recordOperationResultDBInput{
1417+
workflowID: workflowState.WorkflowID,
1418+
operationID: stepID,
1419+
operationName: functionName,
1420+
output: value,
1421+
err: nil,
1422+
}
1423+
1424+
err = s.RecordOperationResult(ctx, recordInput)
1425+
if err != nil {
1426+
return nil, fmt.Errorf("failed to record operation result: %w", err)
1427+
}
1428+
}
1429+
1430+
return value, nil
1431+
}
1432+
12341433
/*******************************/
12351434
/******* QUEUES ********/
12361435
/*******************************/

0 commit comments

Comments
 (0)