Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion dbos/_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def queue_thread(stop_event: threading.Event, dbos: "DBOS") -> None:
execute_workflow_by_id(dbos, id)
except OperationalError as e:
# Ignore serialization error
if not isinstance(e.orig, errors.SerializationFailure):
if not isinstance(
e.orig, (errors.SerializationFailure, errors.LockNotAvailable)
):
dbos.logger.warning(
f"Exception encountered in queue thread: {traceback.format_exc()}"
)
Expand Down
85 changes: 69 additions & 16 deletions dbos/_sys_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ def __init__(self, config: ConfigFile):
host=config["database"]["hostname"],
port=config["database"]["port"],
database="postgres",
# fills the "application_name" column in pg_stat_activity
query={
"application_name": f"dbos_transact_{os.environ.get('DBOS__VMID', 'local')}_{os.environ.get('DBOS__APPVERSION', '')}"
},
)
engine = sa.create_engine(postgres_db_url)
with engine.connect() as conn:
Expand All @@ -207,6 +211,10 @@ def __init__(self, config: ConfigFile):
host=config["database"]["hostname"],
port=config["database"]["port"],
database=sysdb_name,
# fills the "application_name" column in pg_stat_activity
query={
"application_name": f"dbos_transact_{os.environ.get('DBOS__VMID', 'local')}_{os.environ.get('DBOS__APPVERSION', '')}"
},
)

# Create a connection pool for the system database
Expand Down Expand Up @@ -1307,6 +1315,55 @@ def start_queued_workflows(self, queue: "Queue", executor_id: str) -> List[str]:
# Dequeue functions eligible for this worker and ordered by the time at which they were enqueued.
# If there is a global or local concurrency limit N, select only the N oldest enqueued
# functions, else select all of them.

# First let's figure out how many tasks the worker can dequeue
running_tasks_query = (
sa.select(
SystemSchema.workflow_queue.c.executor_id,
sa.func.count().label("task_count"),
)
.where(SystemSchema.workflow_queue.c.queue_name == queue.name)
.where(
SystemSchema.workflow_queue.c.executor_id.isnot(
None
) # Task is dequeued
)
.where(
SystemSchema.workflow_queue.c.completed_at_epoch_ms.is_(
None
) # Task is not completed
)
.group_by(SystemSchema.workflow_queue.c.executor_id)
)
running_tasks_result = c.execute(running_tasks_query).fetchall()
running_tasks_result_dict = {row[0]: row[1] for row in running_tasks_result}
running_tasks_for_this_worker = running_tasks_result_dict.get(
executor_id, 0
) # Get count for current executor

max_tasks = float("inf")
if queue.worker_concurrency is not None:
# Worker local concurrency limit should always be >= running_tasks_for_this_worker
# This should never happen but a check + warning doesn't hurt
if running_tasks_for_this_worker > queue.worker_concurrency:
dbos_logger.warning(
f"Number of tasks on this worker ({running_tasks_for_this_worker}) exceeds the worker concurrency limit ({queue.worker_concurrency})"
)
max_tasks = max(
0, queue.worker_concurrency - running_tasks_for_this_worker
)
if queue.concurrency is not None:
total_running_tasks = sum(running_tasks_result_dict.values())
# Queue global concurrency limit should always be >= running_tasks_count
# This should never happen but a check + warning doesn't hurt
if total_running_tasks > queue.concurrency:
dbos_logger.warning(
f"Total running tasks ({total_running_tasks}) exceeds the global concurrency limit ({queue.concurrency})"
)
available_tasks = max(0, queue.concurrency - total_running_tasks)
max_tasks = min(max_tasks, available_tasks)

# Lookup tasks
query = (
sa.select(
SystemSchema.workflow_queue.c.workflow_uuid,
Expand All @@ -1315,29 +1372,25 @@ def start_queued_workflows(self, queue: "Queue", executor_id: str) -> List[str]:
)
.where(SystemSchema.workflow_queue.c.queue_name == queue.name)
.where(SystemSchema.workflow_queue.c.completed_at_epoch_ms == None)
.where(
# Only select functions that have not been started yet or have been started by this worker
or_(
SystemSchema.workflow_queue.c.executor_id == None,
SystemSchema.workflow_queue.c.executor_id == executor_id,
)
)
.where(SystemSchema.workflow_queue.c.executor_id == None)
.order_by(SystemSchema.workflow_queue.c.created_at_epoch_ms.asc())
.with_for_update(nowait=True) # Error out early
)
# Set a dequeue limit if necessary
if queue.worker_concurrency is not None:
query = query.limit(queue.worker_concurrency)
elif queue.concurrency is not None:
query = query.limit(queue.concurrency)
# Apply limit only if max_tasks is finite
if max_tasks != float("inf"):
query = query.limit(int(max_tasks))

rows = c.execute(query).fetchall()

# Now, get the workflow IDs of functions that have not yet been started
dequeued_ids: List[str] = [row[0] for row in rows if row[1] is None]
# Get the workflow IDs
dequeued_ids: List[str] = [row[0] for row in rows]
if len(dequeued_ids) > 0:
dbos_logger.debug(
f"[{queue.name}] dequeueing {len(dequeued_ids)} task(s)"
)
ret_ids: list[str] = []
dbos_logger.debug(f"[{queue.name}] dequeueing {len(dequeued_ids)} task(s)")
for id in dequeued_ids:

for id in dequeued_ids:
# If we have a limiter, stop starting functions when the number
# of functions started this period exceeds the limit.
if queue.limiter is not None:
Expand Down
163 changes: 139 additions & 24 deletions tests/test_queue.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
import multiprocessing
import multiprocessing.synchronize
import os
import subprocess
import threading
import time
import uuid
from multiprocessing import Process

import pytest
import sqlalchemy as sa
Expand All @@ -17,7 +18,6 @@
SetWorkflowID,
WorkflowHandle,
)
from dbos._error import DBOSDeadLetterQueueError
from dbos._schemas.system_database import SystemSchema
from dbos._sys_db import WorkflowStatusString
from tests.conftest import default_config
Expand Down Expand Up @@ -379,11 +379,6 @@ def test_queue_workflow_in_recovered_workflow(dbos: DBOS) -> None:
return


###########################
# TEST WORKER CONCURRENCY #
###########################


def test_one_at_a_time_with_worker_concurrency(dbos: DBOS) -> None:
wf_counter = 0
flag = False
Expand Down Expand Up @@ -423,12 +418,25 @@ def workflow_two() -> None:


# Declare a workflow globally (we need it to be registered across process under a known name)
start_event = threading.Event()
end_event = threading.Event()


@DBOS.workflow()
def worker_concurrency_test_workflow() -> None:
pass
start_event.set()
end_event.wait()


def run_dbos_test_in_process(i: int) -> None:
local_concurrency_limit: int = 5
global_concurrency_limit: int = local_concurrency_limit * 2


def run_dbos_test_in_process(
i: int,
start_signal: multiprocessing.synchronize.Event,
end_signal: multiprocessing.synchronize.Event,
) -> None:
dbos_config: ConfigFile = {
"name": "test-app",
"language": "python",
Expand All @@ -445,39 +453,144 @@ def run_dbos_test_in_process(i: int) -> None:
},
"telemetry": {},
"env": {},
"application": {},
}
dbos = DBOS(config=dbos_config)
DBOS.launch()

Queue("test_queue", worker_concurrency=1)
time.sleep(
2
) # Give some time for the parent worker to enqueue and for this worker to dequeue

Queue(
"test_queue",
worker_concurrency=local_concurrency_limit,
concurrency=global_concurrency_limit,
)
# Wait to dequeue as many tasks as we can locally
for _ in range(0, local_concurrency_limit):
start_event.wait()
start_event.clear()
# Signal the parent process we've dequeued
start_signal.set()
# Wait for the parent process to signal we can move on
end_signal.wait()
# Complete the task. 1 set should unblock them all
end_event.set()

# Now whatever is in the queue should be cleared up fast (start/end events are already set)
queue_entries_are_cleaned_up(dbos)

DBOS.destroy()


# Test global concurrency and worker utilization by carefully filling the queue up to 1) the local limit 2) the global limit
# For the global limit, we fill the queue in 2 steps, ensuring that the 2nd worker is able to cap its local utilization even
# after having dequeued some tasks already
def test_worker_concurrency_with_n_dbos_instances(dbos: DBOS) -> None:
# Ensure children processes do not share global variables (including DBOS instance) with the parent
multiprocessing.set_start_method("spawn")
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found out that GHA uses "fork", which of course shares the parent's global variables with the child, including the DBOS global instance, which would result in a DBOS Initialized multiple times with conflicting configuration / fastapi information error in the children.


queue = Queue(
"test_queue", limiter={"limit": 0, "period": 1}
) # This process cannot dequeue tasks

# Start N processes to dequeue
# First, start local concurrency limit tasks
handles = []
for _ in range(0, local_concurrency_limit):
handles.append(queue.enqueue(worker_concurrency_test_workflow))

# Start 2 workers
processes = []
for i in range(0, 10):
start_signals = []
end_signals = []
manager = multiprocessing.Manager()
for i in range(0, 2):
os.environ["DBOS__VMID"] = f"test-executor-{i}"
process = Process(target=run_dbos_test_in_process, args=(i,))
start_signal = manager.Event()
start_signals.append(start_signal)
end_signal = manager.Event()
end_signals.append(end_signal)
process = multiprocessing.Process(
target=run_dbos_test_in_process, args=(i, start_signal, end_signal)
)
process.start()
processes.append(process)
del os.environ["DBOS__VMID"]

# Check that a single worker was able to acquire all the tasks
loop = True
while loop:
for signal in start_signals:
signal.wait(timeout=1)
if signal.is_set():
loop = False
executors = []
for handle in handles:
status = handle.get_status()
assert status.status == WorkflowStatusString.PENDING.value
executors.append(status.executor_id)
assert len(set(executors)) == 1

# Now enqueue fewer than the local concurrency limit. Check that the 2nd worker acquired them. We won't have a signal set from the worker, so we need to sleep a little.
handles = []
for _ in range(0, local_concurrency_limit - 1):
handles.append(queue.enqueue(worker_concurrency_test_workflow))
time.sleep(2)
executors = []
for handle in handles:
status = handle.get_status()
assert status.status == WorkflowStatusString.PENDING.value
executors.append(status.executor_id)
assert len(set(executors)) == 1

# Now, enqueue two more tasks. This means qlen > local concurrency limit * 2 and qlen > global concurrency limit
# We should have 1 tasks PENDING and 1 ENQUEUED, thus meeting both local and global concurrency limits
handles = []
for _ in range(0, 2):
handles.append(queue.enqueue(worker_concurrency_test_workflow))
# we can check the signal because the 2nd executor will set it
num_dequeued = 0
while num_dequeued < 2:
for signal in start_signals:
signal.wait(timeout=1)
if signal.is_set():
num_dequeued += 1
executors = []
statuses = []
for handle in handles:
status = handle.get_status()
statuses.append(status.status)
executors.append(status.executor_id)
assert set(statuses) == {
WorkflowStatusString.PENDING.value,
WorkflowStatusString.ENQUEUED.value,
}
assert len(set(executors)) == 2
assert "local" in executors

# Now check in the DB that global concurrency is met
with dbos._sys_db.engine.begin() as conn:
query = (
sa.select(sa.func.count())
.select_from(SystemSchema.workflow_status)
.where(
SystemSchema.workflow_status.c.status
== WorkflowStatusString.PENDING.value
)
)
row = conn.execute(query).fetchone()

# Enqueue N tasks but ensure this worker cannot dequeue
assert row is not None, "Query returned no results"
count = row[0]
assert (
count == global_concurrency_limit
), f"Expected {global_concurrency_limit} workflows, found {count}"

queue = Queue("test_queue", limiter={"limit": 0, "period": 1})
for i in range(0, 10):
queue.enqueue(worker_concurrency_test_workflow)
# Signal the workers they can move on
for signal in end_signals:
signal.set()

for process in processes:
process.join()

# Verify all queue entries eventually get cleaned up.
assert queue_entries_are_cleaned_up(dbos)


# Test error cases where we have duplicated workflows starting with the same workflow ID.
def test_duplicate_workflow_id(dbos: DBOS, caplog: pytest.LogCaptureFixture) -> None:
Expand Down Expand Up @@ -661,7 +774,9 @@ def blocked_workflow(i: int) -> None:
def noop() -> None:
pass

queue = Queue("test_queue", concurrency=2)
queue = Queue(
"test_queue", worker_concurrency=2
) # covers global concurrency limit because we have a single process
handle1 = queue.enqueue(blocked_workflow, 0)
handle2 = queue.enqueue(blocked_workflow, 1)
handle3 = queue.enqueue(noop)
Expand Down