Skip to content

Commit d4b5639

Browse files
committed
feat(agent, task): add concurrency limiter and message sink protocols
Implemented a concurrency registry to control concurrent task execution for agents, allowing users to set limits per agent name. Also added message sink protocols for managing sequenced task execution messages, enabling replay of agent outputs and tool calls. New files include: - agent/concurrency.py - task/message_sink.py Updated the task module to include the new protocols in exports. Total changes across 13 files with 1099 insertions.
1 parent 4c68e61 commit d4b5639

File tree

13 files changed

+1099
-4
lines changed

13 files changed

+1099
-4
lines changed
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
"""Per-Agent Concurrency Limiter for PraisonAI Agents.
2+
3+
Provides a registry-based approach to limit concurrent task execution
4+
per agent name. No Agent constructor param bloat — uses a global registry.
5+
6+
Usage:
7+
from praisonaiagents.agent.concurrency import get_concurrency_registry
8+
9+
registry = get_concurrency_registry()
10+
registry.set_limit("researcher", 2) # max 2 concurrent tasks
11+
12+
# In async code:
13+
async with registry.throttle("researcher"):
14+
await do_work()
15+
16+
# Or manual:
17+
await registry.acquire("researcher")
18+
try:
19+
await do_work()
20+
finally:
21+
registry.release("researcher")
22+
"""
23+
24+
import asyncio
25+
import threading
26+
from contextlib import asynccontextmanager
27+
from typing import Dict, Optional
28+
29+
from praisonaiagents._logging import get_logger
30+
31+
logger = get_logger(__name__)
32+
33+
34+
class ConcurrencyRegistry:
35+
"""Registry for per-agent concurrency limits.
36+
37+
Thread-safe. Each agent name maps to an asyncio.Semaphore.
38+
Limit of 0 means unlimited (no throttling).
39+
"""
40+
41+
def __init__(self, default_limit: int = 0):
42+
self._default_limit = default_limit
43+
self._limits: Dict[str, int] = {}
44+
self._semaphores: Dict[str, asyncio.Semaphore] = {}
45+
self._lock = threading.Lock()
46+
47+
def set_limit(self, agent_name: str, max_concurrent: int) -> None:
48+
"""Set concurrency limit for an agent.
49+
50+
Args:
51+
agent_name: Agent identifier
52+
max_concurrent: Max concurrent tasks (0 = unlimited)
53+
"""
54+
with self._lock:
55+
self._limits[agent_name] = max_concurrent
56+
# Reset semaphore so next acquire creates a fresh one
57+
self._semaphores.pop(agent_name, None)
58+
59+
def get_limit(self, agent_name: str) -> int:
60+
"""Get concurrency limit for an agent."""
61+
with self._lock:
62+
return self._limits.get(agent_name, self._default_limit)
63+
64+
def remove_limit(self, agent_name: str) -> None:
65+
"""Remove concurrency limit for an agent (reverts to default)."""
66+
with self._lock:
67+
self._limits.pop(agent_name, None)
68+
self._semaphores.pop(agent_name, None)
69+
70+
def _get_semaphore(self, agent_name: str) -> Optional[asyncio.Semaphore]:
71+
"""Get or create semaphore for agent. Returns None if unlimited."""
72+
with self._lock:
73+
limit = self._limits.get(agent_name, self._default_limit)
74+
if limit <= 0:
75+
return None
76+
if agent_name not in self._semaphores:
77+
self._semaphores[agent_name] = asyncio.Semaphore(limit)
78+
return self._semaphores[agent_name]
79+
80+
async def acquire(self, agent_name: str) -> None:
81+
"""Acquire concurrency slot for agent. No-op if unlimited."""
82+
sem = self._get_semaphore(agent_name)
83+
if sem is not None:
84+
await sem.acquire()
85+
86+
def acquire_sync(self, agent_name: str) -> None:
87+
"""Synchronous acquire — for non-async code paths.
88+
89+
Note: This creates/reuses an event loop internally.
90+
Prefer async acquire() when possible.
91+
"""
92+
sem = self._get_semaphore(agent_name)
93+
if sem is None:
94+
return
95+
try:
96+
asyncio.get_running_loop()
97+
# If we're in an async context, we can't block
98+
# Just try_acquire or no-op with warning
99+
if not sem._value > 0:
100+
logger.warning(
101+
f"Sync acquire for '{agent_name}' while async loop running and semaphore full. "
102+
f"Consider using async acquire() instead."
103+
)
104+
# Decrement manually for sync context
105+
sem._value = max(0, sem._value - 1)
106+
except RuntimeError:
107+
# No running loop — safe to use asyncio.run
108+
asyncio.get_event_loop().run_until_complete(sem.acquire())
109+
110+
def release(self, agent_name: str) -> None:
111+
"""Release concurrency slot for agent. No-op if unlimited."""
112+
with self._lock:
113+
sem = self._semaphores.get(agent_name)
114+
if sem is not None:
115+
try:
116+
sem.release()
117+
except ValueError:
118+
pass # Already fully released
119+
120+
@asynccontextmanager
121+
async def throttle(self, agent_name: str):
122+
"""Async context manager for throttled execution.
123+
124+
Usage:
125+
async with registry.throttle("agent_name"):
126+
await do_work()
127+
"""
128+
await self.acquire(agent_name)
129+
try:
130+
yield
131+
finally:
132+
self.release(agent_name)
133+
134+
135+
# Singleton
136+
_registry: Optional[ConcurrencyRegistry] = None
137+
_registry_lock = threading.Lock()
138+
139+
140+
def get_concurrency_registry() -> ConcurrencyRegistry:
141+
"""Get the global concurrency registry singleton."""
142+
global _registry
143+
if _registry is None:
144+
with _registry_lock:
145+
if _registry is None:
146+
_registry = ConcurrencyRegistry()
147+
return _registry
Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
11
"""Task module for AI agent tasks"""
22
from .task import Task
3+
from .protocols import TaskStatus, TaskLifecycleManager, InvalidTransitionError, TaskLifecycleProtocol
4+
from .message_sink import (
5+
TaskMessage, TaskMessageSinkProtocol, NoOpTaskMessageSink,
6+
InMemoryTaskMessageSink, TaskMessageEmitter,
7+
)
38

4-
__all__ = ['Task']
9+
__all__ = [
10+
'Task', 'TaskStatus', 'TaskLifecycleManager', 'InvalidTransitionError', 'TaskLifecycleProtocol',
11+
'TaskMessage', 'TaskMessageSinkProtocol', 'NoOpTaskMessageSink',
12+
'InMemoryTaskMessageSink', 'TaskMessageEmitter',
13+
]
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""Task Message Sink Protocols for PraisonAI Agents.
2+
3+
Provides pluggable persistence for sequenced task execution messages.
4+
Matches Multica's ReportTaskMessages pattern but as a protocol,
5+
enabling replay of agent output, tool calls, errors, and status changes.
6+
7+
Usage:
8+
from praisonaiagents.task.message_sink import InMemoryTaskMessageSink, TaskMessageEmitter
9+
10+
sink = InMemoryTaskMessageSink()
11+
emitter = TaskMessageEmitter(task_id="t1", sink=sink, agent_name="researcher")
12+
emitter.emit("agent_output", "Hello world")
13+
emitter.emit("tool_call", "search_web('AI trends')")
14+
15+
# Replay
16+
messages = sink.replay("t1")
17+
"""
18+
19+
from dataclasses import dataclass, field
20+
from datetime import datetime
21+
from typing import Protocol, runtime_checkable, List, Optional, Dict, Any
22+
23+
24+
@dataclass
25+
class TaskMessage:
26+
"""A single sequenced message from task execution.
27+
28+
Attributes:
29+
task_id: Task this message belongs to
30+
seq_num: Sequence number for ordering (0-indexed)
31+
msg_type: Message type (agent_output, tool_call, tool_result, error, status)
32+
content: Message content
33+
agent_name: Agent that produced this message
34+
metadata: Optional extra metadata
35+
timestamp: ISO timestamp
36+
"""
37+
task_id: str
38+
seq_num: int
39+
msg_type: str
40+
content: str
41+
agent_name: Optional[str] = None
42+
metadata: Dict[str, Any] = field(default_factory=dict)
43+
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
44+
45+
def to_dict(self) -> Dict[str, Any]:
46+
"""Convert to dictionary."""
47+
return {
48+
"task_id": self.task_id,
49+
"seq_num": self.seq_num,
50+
"msg_type": self.msg_type,
51+
"content": self.content,
52+
"agent_name": self.agent_name,
53+
"metadata": self.metadata,
54+
"timestamp": self.timestamp,
55+
}
56+
57+
58+
@runtime_checkable
59+
class TaskMessageSinkProtocol(Protocol):
60+
"""Protocol for persisting task execution messages.
61+
62+
Implementations can write to databases, files, WebSockets, etc.
63+
"""
64+
65+
def emit(self, message: TaskMessage) -> None:
66+
"""Persist a single task message."""
67+
...
68+
69+
def replay(self, task_id: str) -> List[TaskMessage]:
70+
"""Replay all messages for a task, ordered by seq_num."""
71+
...
72+
73+
74+
class NoOpTaskMessageSink:
75+
"""Default sink that does nothing. Zero overhead."""
76+
77+
def emit(self, message: TaskMessage) -> None:
78+
pass
79+
80+
def replay(self, task_id: str) -> List[TaskMessage]:
81+
return []
82+
83+
84+
class InMemoryTaskMessageSink:
85+
"""In-memory sink for testing and debugging.
86+
87+
Stores all messages in a list for inspection and replay.
88+
"""
89+
90+
def __init__(self):
91+
self.messages: List[TaskMessage] = []
92+
93+
def emit(self, message: TaskMessage) -> None:
94+
self.messages.append(message)
95+
96+
def replay(self, task_id: str) -> List[TaskMessage]:
97+
"""Replay messages for a task, ordered by seq_num."""
98+
task_messages = [m for m in self.messages if m.task_id == task_id]
99+
return sorted(task_messages, key=lambda m: m.seq_num)
100+
101+
def clear(self) -> None:
102+
"""Clear all messages."""
103+
self.messages.clear()
104+
105+
106+
class TaskMessageEmitter:
107+
"""Convenience emitter that auto-sequences messages for a task.
108+
109+
Maintains a per-task sequence counter so callers don't need
110+
to track seq_num manually.
111+
112+
Usage:
113+
emitter = TaskMessageEmitter(task_id="t1", sink=sink)
114+
emitter.emit("agent_output", "Hello") # seq_num=0
115+
emitter.emit("tool_call", "search()") # seq_num=1
116+
"""
117+
118+
def __init__(
119+
self,
120+
task_id: str,
121+
sink: TaskMessageSinkProtocol,
122+
agent_name: Optional[str] = None,
123+
):
124+
self.task_id = task_id
125+
self.sink = sink
126+
self.agent_name = agent_name
127+
self._seq_num = 0
128+
129+
def emit(
130+
self,
131+
msg_type: str,
132+
content: str,
133+
agent_name: Optional[str] = None,
134+
metadata: Optional[Dict[str, Any]] = None,
135+
) -> None:
136+
"""Emit a message with auto-incremented sequence number."""
137+
message = TaskMessage(
138+
task_id=self.task_id,
139+
seq_num=self._seq_num,
140+
msg_type=msg_type,
141+
content=content,
142+
agent_name=agent_name or self.agent_name,
143+
metadata=metadata or {},
144+
)
145+
self.sink.emit(message)
146+
self._seq_num += 1

0 commit comments

Comments
 (0)