Skip to content

Commit 08062b8

Browse files
authored
Fix CrewAI OSS concurrent execution (ag-ui-protocol#131)
* fix concurrent execution problems by adding a global queue * bump version
1 parent 7278899 commit 08062b8

File tree

2 files changed

+159
-115
lines changed

2 files changed

+159
-115
lines changed

typescript-sdk/integrations/crewai/python/ag_ui_crewai/endpoint.py

Lines changed: 158 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,18 @@
33
"""
44
import copy
55
import asyncio
6-
from typing import List
6+
from typing import List, Optional
77
from fastapi import FastAPI, Request
88
from fastapi.responses import StreamingResponse
99

1010
from crewai.utilities.events import (
11-
crewai_event_bus,
1211
FlowStartedEvent,
1312
FlowFinishedEvent,
1413
MethodExecutionStartedEvent,
1514
MethodExecutionFinishedEvent,
1615
)
1716
from crewai.flow.flow import Flow
17+
from crewai.utilities.events.base_event_listener import BaseEventListener
1818
from crewai import Crew
1919

2020
from ag_ui.core import (
@@ -47,10 +47,152 @@
4747
from .sdk import litellm_messages_to_ag_ui_messages
4848
from .crews import ChatWithCrewFlow
4949

50+
QUEUES = {}
51+
QUEUES_LOCK = asyncio.Lock()
52+
53+
54+
async def create_queue(flow: object) -> asyncio.Queue:
55+
"""Create a queue for a flow."""
56+
queue_id = id(flow)
57+
async with QUEUES_LOCK:
58+
queue = asyncio.Queue()
59+
QUEUES[queue_id] = queue
60+
return queue
61+
62+
63+
def get_queue(flow: object) -> Optional[asyncio.Queue]:
64+
"""Get the queue for a flow."""
65+
queue_id = id(flow)
66+
# not using a lock here should be fine
67+
return QUEUES.get(queue_id)
68+
69+
async def delete_queue(flow: object) -> None:
70+
"""Delete the queue for a flow."""
71+
queue_id = id(flow)
72+
async with QUEUES_LOCK:
73+
if queue_id in QUEUES:
74+
del QUEUES[queue_id]
75+
76+
GLOBAL_EVENT_LISTENER = None
77+
78+
class FastAPICrewFlowEventListener(BaseEventListener):
79+
"""FastAPI CrewFlow event listener"""
80+
81+
def setup_listeners(self, crewai_event_bus):
82+
"""Setup listeners for the FastAPI CrewFlow event listener"""
83+
@crewai_event_bus.on(FlowStartedEvent)
84+
def _(source, event): # pylint: disable=unused-argument
85+
queue = get_queue(source)
86+
if queue is not None:
87+
queue.put_nowait(
88+
RunStartedEvent(
89+
type=EventType.RUN_STARTED,
90+
# will be replaced by the correct thread_id/run_id when sending the event
91+
thread_id="?",
92+
run_id="?",
93+
),
94+
)
95+
@crewai_event_bus.on(FlowFinishedEvent)
96+
def _(source, event): # pylint: disable=unused-argument
97+
queue = get_queue(source)
98+
if queue is not None:
99+
queue.put_nowait(
100+
RunFinishedEvent(
101+
type=EventType.RUN_FINISHED,
102+
thread_id="?",
103+
run_id="?",
104+
),
105+
)
106+
queue.put_nowait(None)
107+
@crewai_event_bus.on(MethodExecutionStartedEvent)
108+
def _(source, event):
109+
queue = get_queue(source)
110+
if queue is not None:
111+
queue.put_nowait(
112+
StepStartedEvent(
113+
type=EventType.STEP_STARTED,
114+
step_name=event.method_name
115+
)
116+
)
117+
@crewai_event_bus.on(MethodExecutionFinishedEvent)
118+
def _(source, event):
119+
queue = get_queue(source)
120+
if queue is not None:
121+
messages = litellm_messages_to_ag_ui_messages(source.state.messages)
122+
123+
queue.put_nowait(
124+
MessagesSnapshotEvent(
125+
type=EventType.MESSAGES_SNAPSHOT,
126+
messages=messages
127+
)
128+
)
129+
queue.put_nowait(
130+
StateSnapshotEvent(
131+
type=EventType.STATE_SNAPSHOT,
132+
snapshot=source.state
133+
)
134+
)
135+
queue.put_nowait(
136+
StepFinishedEvent(
137+
type=EventType.STEP_FINISHED,
138+
step_name=event.method_name
139+
)
140+
)
141+
@crewai_event_bus.on(BridgedTextMessageChunkEvent)
142+
def _(source, event):
143+
queue = get_queue(source)
144+
if queue is not None:
145+
queue.put_nowait(
146+
TextMessageChunkEvent(
147+
type=EventType.TEXT_MESSAGE_CHUNK,
148+
message_id=event.message_id,
149+
role=event.role,
150+
delta=event.delta,
151+
)
152+
)
153+
@crewai_event_bus.on(BridgedToolCallChunkEvent)
154+
def _(source, event):
155+
queue = get_queue(source)
156+
if queue is not None:
157+
queue.put_nowait(
158+
ToolCallChunkEvent(
159+
type=EventType.TOOL_CALL_CHUNK,
160+
tool_call_id=event.tool_call_id,
161+
tool_call_name=event.tool_call_name,
162+
delta=event.delta,
163+
)
164+
)
165+
@crewai_event_bus.on(BridgedCustomEvent)
166+
def _(source, event):
167+
queue = get_queue(source)
168+
if queue is not None:
169+
queue.put_nowait(
170+
CustomEvent(
171+
type=EventType.CUSTOM,
172+
name=event.name,
173+
value=event.value
174+
)
175+
)
176+
@crewai_event_bus.on(BridgedStateSnapshotEvent)
177+
def _(source, event):
178+
queue = get_queue(source)
179+
if queue is not None:
180+
queue.put_nowait(
181+
StateSnapshotEvent(
182+
type=EventType.STATE_SNAPSHOT,
183+
snapshot=event.snapshot
184+
)
185+
)
50186

51187
def add_crewai_flow_fastapi_endpoint(app: FastAPI, flow: Flow, path: str = "/"):
52188
"""Adds a CrewAI endpoint to the FastAPI app."""
189+
global GLOBAL_EVENT_LISTENER # pylint: disable=global-statement
53190

191+
# Set up the global event listener singleton
192+
# we are doing this here because calling add_crewai_flow_fastapi_endpoint is a clear indicator
193+
# that we are not running on CrewAI enterprise
194+
if GLOBAL_EVENT_LISTENER is None:
195+
GLOBAL_EVENT_LISTENER = FastAPICrewFlowEventListener()
54196

55197
@app.post(path)
56198
async def agentic_chat_endpoint(input_data: RunAgentInput, request: Request):
@@ -71,120 +213,21 @@ async def agentic_chat_endpoint(input_data: RunAgentInput, request: Request):
71213
)
72214

73215
async def event_generator():
74-
queue = asyncio.Queue()
216+
queue = await create_queue(flow_copy)
75217
token = flow_context.set(flow_copy)
76218
try:
77-
with crewai_event_bus.scoped_handlers():
78-
79-
@crewai_event_bus.on(FlowStartedEvent)
80-
def _(source, event): # pylint: disable=unused-argument
81-
if source == flow_copy:
82-
queue.put_nowait(
83-
RunStartedEvent(
84-
type=EventType.RUN_STARTED,
85-
thread_id=input_data.thread_id,
86-
run_id=input_data.run_id,
87-
),
88-
)
89-
90-
@crewai_event_bus.on(FlowFinishedEvent)
91-
def _(source, event): # pylint: disable=unused-argument
92-
if source == flow_copy:
93-
queue.put_nowait(
94-
RunFinishedEvent(
95-
type=EventType.RUN_FINISHED,
96-
thread_id=input_data.thread_id,
97-
run_id=input_data.run_id,
98-
),
99-
)
100-
queue.put_nowait(None)
101-
102-
@crewai_event_bus.on(MethodExecutionStartedEvent)
103-
def _(source, event):
104-
if source == flow_copy:
105-
queue.put_nowait(
106-
StepStartedEvent(
107-
type=EventType.STEP_STARTED,
108-
step_name=event.method_name
109-
)
110-
)
111-
112-
@crewai_event_bus.on(MethodExecutionFinishedEvent)
113-
def _(source, event):
114-
if source == flow_copy:
115-
messages = litellm_messages_to_ag_ui_messages(source.state.messages)
116-
117-
queue.put_nowait(
118-
MessagesSnapshotEvent(
119-
type=EventType.MESSAGES_SNAPSHOT,
120-
messages=messages
121-
)
122-
)
123-
queue.put_nowait(
124-
StateSnapshotEvent(
125-
type=EventType.STATE_SNAPSHOT,
126-
snapshot=source.state
127-
)
128-
)
129-
queue.put_nowait(
130-
StepFinishedEvent(
131-
type=EventType.STEP_FINISHED,
132-
step_name=event.method_name
133-
)
134-
)
135-
136-
@crewai_event_bus.on(BridgedTextMessageChunkEvent)
137-
def _(source, event):
138-
if source == flow_copy:
139-
queue.put_nowait(
140-
TextMessageChunkEvent(
141-
type=EventType.TEXT_MESSAGE_CHUNK,
142-
message_id=event.message_id,
143-
role=event.role,
144-
delta=event.delta,
145-
)
146-
)
147-
148-
@crewai_event_bus.on(BridgedToolCallChunkEvent)
149-
def _(source, event):
150-
if source == flow_copy:
151-
queue.put_nowait(
152-
ToolCallChunkEvent(
153-
type=EventType.TOOL_CALL_CHUNK,
154-
tool_call_id=event.tool_call_id,
155-
tool_call_name=event.tool_call_name,
156-
delta=event.delta,
157-
)
158-
)
159-
160-
@crewai_event_bus.on(BridgedCustomEvent)
161-
def _(source, event):
162-
if source == flow_copy:
163-
queue.put_nowait(
164-
CustomEvent(
165-
type=EventType.CUSTOM,
166-
name=event.name,
167-
value=event.value
168-
)
169-
)
170-
171-
@crewai_event_bus.on(BridgedStateSnapshotEvent)
172-
def _(source, event):
173-
if source == flow_copy:
174-
queue.put_nowait(
175-
StateSnapshotEvent(
176-
type=EventType.STATE_SNAPSHOT,
177-
snapshot=event.snapshot
178-
)
179-
)
180-
181-
asyncio.create_task(flow_copy.kickoff_async(inputs=inputs))
182-
183-
while True:
184-
item = await queue.get()
185-
if item is None:
186-
break
187-
yield encoder.encode(item)
219+
asyncio.create_task(flow_copy.kickoff_async(inputs=inputs))
220+
221+
while True:
222+
item = await queue.get()
223+
if item is None:
224+
break
225+
226+
if item.type == EventType.RUN_STARTED or item.type == EventType.RUN_FINISHED:
227+
item.thread_id = input_data.thread_id
228+
item.run_id = input_data.run_id
229+
230+
yield encoder.encode(item)
188231

189232
except Exception as e: # pylint: disable=broad-exception-caught
190233
yield encoder.encode(
@@ -196,6 +239,7 @@ def _(source, event):
196239
)
197240
)
198241
finally:
242+
await delete_queue(flow_copy)
199243
flow_context.reset(token)
200244

201245
return StreamingResponse(event_generator(), media_type=encoder.get_content_type())

typescript-sdk/integrations/crewai/python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "ag-ui-crewai"
3-
version = "0.1.1"
3+
version = "0.1.2"
44
description = ""
55
authors = ["Markus Ecker <[email protected]>"]
66
readme = "README.md"

0 commit comments

Comments
 (0)