Skip to content

Commit c3ded62

Browse files
committed
dequeueWorkflow accepts an input struct
1 parent ab763ff commit c3ded62

File tree

3 files changed

+35
-27
lines changed

3 files changed

+35
-27
lines changed

dbos/queue.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,11 @@ func (qr *queueRunner) run(ctx *dbosContext) {
179179
// Iterate through all queues in the registry
180180
for queueName, queue := range qr.workflowQueueRegistry {
181181
// Call DequeueWorkflows for each queue
182-
dequeuedWorkflows, err := ctx.systemDB.dequeueWorkflows(ctx, queue, ctx.executorID, ctx.applicationVersion)
182+
dequeuedWorkflows, err := ctx.systemDB.dequeueWorkflows(ctx, dequeueWorkflowsInput{
183+
queue: queue,
184+
executorID: ctx.executorID,
185+
applicationVersion: ctx.applicationVersion,
186+
})
183187
if err != nil {
184188
if pgErr, ok := err.(*pgconn.PgError); ok {
185189
switch pgErr.Code {

dbos/system_database.go

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ type systemDatabase interface {
5858
sleep(ctx context.Context, duration time.Duration) (time.Duration, error)
5959

6060
// Queues
61-
dequeueWorkflows(ctx context.Context, queue WorkflowQueue, executorID, applicationVersion string) ([]dequeuedWorkflow, error)
61+
dequeueWorkflows(ctx context.Context, input dequeueWorkflowsInput) ([]dequeuedWorkflow, error)
6262
clearQueueAssignment(ctx context.Context, workflowID string) (bool, error)
6363
}
6464

@@ -1864,8 +1864,13 @@ type dequeuedWorkflow struct {
18641864
input string
18651865
}
18661866

1867-
// TODO input struct
1868-
func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, executorID, applicationVersion string) ([]dequeuedWorkflow, error) {
1867+
type dequeueWorkflowsInput struct {
1868+
queue WorkflowQueue
1869+
executorID string
1870+
applicationVersion string
1871+
}
1872+
1873+
func (s *sysDB) dequeueWorkflows(ctx context.Context, input dequeueWorkflowsInput) ([]dequeuedWorkflow, error) {
18691874
// Begin transaction with snapshot isolation
18701875
tx, err := s.pool.Begin(ctx)
18711876
if err != nil {
@@ -1882,8 +1887,8 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, execu
18821887
// First check the rate limiter
18831888
startTimeMs := time.Now().UnixMilli()
18841889
var numRecentQueries int
1885-
if queue.RateLimit != nil {
1886-
limiterPeriod := time.Duration(queue.RateLimit.Period * float64(time.Second))
1890+
if input.queue.RateLimit != nil {
1891+
limiterPeriod := time.Duration(input.queue.RateLimit.Period * float64(time.Second))
18871892

18881893
// Calculate the cutoff time: current time minus limiter period
18891894
cutoffTimeMs := time.Now().Add(-limiterPeriod).UnixMilli()
@@ -1897,30 +1902,30 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, execu
18971902
AND started_at_epoch_ms > $3`
18981903

18991904
err := tx.QueryRow(ctx, limiterQuery,
1900-
queue.Name,
1905+
input.queue.Name,
19011906
WorkflowStatusEnqueued,
19021907
cutoffTimeMs).Scan(&numRecentQueries)
19031908
if err != nil {
19041909
return nil, fmt.Errorf("failed to query rate limiter: %w", err)
19051910
}
19061911

1907-
if numRecentQueries >= queue.RateLimit.Limit {
1912+
if numRecentQueries >= input.queue.RateLimit.Limit {
19081913
return []dequeuedWorkflow{}, nil
19091914
}
19101915
}
19111916

19121917
// Calculate max_tasks based on concurrency limits
1913-
maxTasks := queue.MaxTasksPerIteration
1918+
maxTasks := input.queue.MaxTasksPerIteration
19141919

1915-
if queue.WorkerConcurrency != nil || queue.GlobalConcurrency != nil {
1920+
if input.queue.WorkerConcurrency != nil || input.queue.GlobalConcurrency != nil {
19161921
// Count pending workflows by executor
19171922
pendingQuery := `
19181923
SELECT executor_id, COUNT(*) as task_count
19191924
FROM dbos.workflow_status
19201925
WHERE queue_name = $1 AND status = $2
19211926
GROUP BY executor_id`
19221927

1923-
rows, err := tx.Query(ctx, pendingQuery, queue.Name, WorkflowStatusPending)
1928+
rows, err := tx.Query(ctx, pendingQuery, input.queue.Name, WorkflowStatusPending)
19241929
if err != nil {
19251930
return nil, fmt.Errorf("failed to query pending workflows: %w", err)
19261931
}
@@ -1936,28 +1941,28 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, execu
19361941
pendingWorkflowsDict[executorIDRow] = taskCount
19371942
}
19381943

1939-
localPendingWorkflows := pendingWorkflowsDict[executorID]
1944+
localPendingWorkflows := pendingWorkflowsDict[input.executorID]
19401945

19411946
// Check worker concurrency limit
1942-
if queue.WorkerConcurrency != nil {
1943-
workerConcurrency := *queue.WorkerConcurrency
1947+
if input.queue.WorkerConcurrency != nil {
1948+
workerConcurrency := *input.queue.WorkerConcurrency
19441949
if localPendingWorkflows > workerConcurrency {
1945-
s.logger.Warn("Local pending workflows on queue exceeds worker concurrency limit", "local_pending", localPendingWorkflows, "queue_name", queue.Name, "concurrency_limit", workerConcurrency)
1950+
s.logger.Warn("Local pending workflows on queue exceeds worker concurrency limit", "local_pending", localPendingWorkflows, "queue_name", input.queue.Name, "concurrency_limit", workerConcurrency)
19461951
}
19471952
availableWorkerTasks := max(workerConcurrency-localPendingWorkflows, 0)
19481953
maxTasks = availableWorkerTasks
19491954
}
19501955

19511956
// Check global concurrency limit
1952-
if queue.GlobalConcurrency != nil {
1957+
if input.queue.GlobalConcurrency != nil {
19531958
globalPendingWorkflows := 0
19541959
for _, count := range pendingWorkflowsDict {
19551960
globalPendingWorkflows += count
19561961
}
19571962

1958-
concurrency := *queue.GlobalConcurrency
1963+
concurrency := *input.queue.GlobalConcurrency
19591964
if globalPendingWorkflows > concurrency {
1960-
s.logger.Warn("Total pending workflows on queue exceeds global concurrency limit", "total_pending", globalPendingWorkflows, "queue_name", queue.Name, "concurrency_limit", concurrency)
1965+
s.logger.Warn("Total pending workflows on queue exceeds global concurrency limit", "total_pending", globalPendingWorkflows, "queue_name", input.queue.Name, "concurrency_limit", concurrency)
19611966
}
19621967
availableTasks := max(concurrency-globalPendingWorkflows, 0)
19631968
if availableTasks < maxTasks {
@@ -1969,7 +1974,7 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, execu
19691974
// Build the query to select workflows for dequeueing
19701975
// Use SKIP LOCKED when no global concurrency is set to avoid blocking,
19711976
// otherwise use NOWAIT to ensure consistent view across processes
1972-
skipLocks := queue.GlobalConcurrency == nil
1977+
skipLocks := input.queue.GlobalConcurrency == nil
19731978
var lockClause string
19741979
if skipLocks {
19751980
lockClause = "FOR UPDATE SKIP LOCKED"
@@ -1978,7 +1983,7 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, execu
19781983
}
19791984

19801985
var query string
1981-
if queue.PriorityEnabled {
1986+
if input.queue.PriorityEnabled {
19821987
query = fmt.Sprintf(`
19831988
SELECT workflow_uuid
19841989
FROM dbos.workflow_status
@@ -2003,7 +2008,7 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, execu
20032008
}
20042009

20052010
// Execute the query to get workflow IDs
2006-
rows, err := tx.Query(ctx, query, queue.Name, WorkflowStatusEnqueued, applicationVersion)
2011+
rows, err := tx.Query(ctx, query, input.queue.Name, WorkflowStatusEnqueued, input.applicationVersion)
20072012
if err != nil {
20082013
return nil, fmt.Errorf("failed to query enqueued workflows: %w", err)
20092014
}
@@ -2026,15 +2031,15 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, execu
20262031
}
20272032

20282033
if len(dequeuedIDs) > 0 {
2029-
s.logger.Debug("attempting to dequeue task(s)", "queueName", queue.Name, "numTasks", len(dequeuedIDs))
2034+
s.logger.Debug("attempting to dequeue task(s)", "queueName", input.queue.Name, "numTasks", len(dequeuedIDs))
20302035
}
20312036

20322037
// Update workflows to PENDING status and get their details
20332038
var retWorkflows []dequeuedWorkflow
20342039
for _, id := range dequeuedIDs {
20352040
// If we have a limiter, stop dequeueing workflows when the number of workflows started this period exceeds the limit.
2036-
if queue.RateLimit != nil {
2037-
if len(retWorkflows)+numRecentQueries >= queue.RateLimit.Limit {
2041+
if input.queue.RateLimit != nil {
2042+
if len(retWorkflows)+numRecentQueries >= input.queue.RateLimit.Limit {
20382043
break
20392044
}
20402045
}
@@ -2060,8 +2065,8 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, queue WorkflowQueue, execu
20602065
var inputString *string
20612066
err := tx.QueryRow(ctx, updateQuery,
20622067
WorkflowStatusPending,
2063-
applicationVersion,
2064-
executorID,
2068+
input.applicationVersion,
2069+
input.executorID,
20652070
startTimeMs,
20662071
id).Scan(&retWorkflow.name, &inputString)
20672072

dbos/workflow.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,6 @@ func (h *workflowPollingHandle[R]) GetResult() (R, error) {
182182
if result != nil {
183183
typedResult, ok := result.(R)
184184
if !ok {
185-
// TODO check what this looks like in practice
186185
return *new(R), newWorkflowUnexpectedResultType(h.workflowID, fmt.Sprintf("%T", new(R)), fmt.Sprintf("%T", result))
187186
}
188187
// If we are calling GetResult inside a workflow, record the result as a step result

0 commit comments

Comments
 (0)