Commit b619b19

[plan][python] Built-in actions support async execution.
Co-authored-by: Shekharrajak <shekharrajak@live.com>
1 parent e0b644c commit b619b19

5 files changed, 201 insertions(+), 28 deletions(-)

python/flink_agents/api/core_options.py

Lines changed: 22 additions & 0 deletions
@@ -101,3 +101,25 @@ class AgentConfigOptions(metaclass=AgentConfigOptionsMeta):
         config_type=int,
         default=3,
     )
+
+
+class AgentExecutionOptions(metaclass=AgentConfigOptionsMeta):
+    """Execution options for Flink Agents."""
+
+    CHAT_ASYNC = ConfigOption(
+        key="chat.async",
+        config_type=bool,
+        default=True,
+    )
+
+    TOOL_CALL_ASYNC = ConfigOption(
+        key="tool-call.async",
+        config_type=bool,
+        default=True,
+    )
+
+    RAG_ASYNC = ConfigOption(
+        key="rag.async",
+        config_type=bool,
+        default=True,
+    )
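
For orientation: these three options gate whether the built-in chat, tool-call, and RAG actions offload their blocking calls to the runtime's async executor, and all of them default to True. Below is a minimal sketch of how an action consults one of these options, mirroring the pattern this commit applies in chat_model_action.py further down; the helper name _chat_maybe_async is illustrative only and not part of the commit.

    from typing import Any, List

    from flink_agents.api.chat_message import ChatMessage
    from flink_agents.api.core_options import AgentExecutionOptions
    from flink_agents.api.runner_context import RunnerContext


    async def _chat_maybe_async(
        chat_model: Any, messages: List[ChatMessage], ctx: RunnerContext
    ) -> ChatMessage:
        """Illustrative helper: branch on CHAT_ASYNC before calling the model."""
        if ctx.config.get(AgentExecutionOptions.CHAT_ASYNC):
            # Offload the blocking chat call to the runtime's async executor.
            return await ctx.execute_async(chat_model.chat, messages)
        # Otherwise run the chat call synchronously in the calling thread.
        return chat_model.chat(messages)
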
Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+import time
+import uuid
+from typing import Any, Dict, Sequence
+
+from pyflink.datastream import StreamExecutionEnvironment
+from typing_extensions import override
+
+from flink_agents.api.agents.agent import Agent
+from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.chat_models.chat_model import BaseChatModelSetup
+from flink_agents.api.decorators import action, chat_model_setup, tool
+from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent
+from flink_agents.api.events.event import InputEvent, OutputEvent
+from flink_agents.api.execution_environment import AgentsExecutionEnvironment
+from flink_agents.api.resource import ResourceDescriptor
+from flink_agents.api.runner_context import RunnerContext
+from flink_agents.api.tools.tool import ToolType
+
+
+class SlowMockChatModel(BaseChatModelSetup):
+    """Mock ChatModel with slow connection."""
+
+    @property
+    def model_kwargs(self) -> Dict[str, Any]:  # noqa: D102
+        return {}
+
+    @override
+    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage:
+        time.sleep(5)  # Simulate network delay
+        if "sum" in messages[-1].content:
+            input = messages[-1].content
+            function = {"name": "add", "arguments": {"a": 1, "b": 2}}
+            tool_call = {
+                "id": uuid.uuid4(),
+                "type": ToolType.FUNCTION,
+                "function": function,
+            }
+            return ChatMessage(
+                role=MessageRole.ASSISTANT, content=input, tool_calls=[tool_call]
+            )
+        else:
+            content = "\n".join([message.content for message in messages])
+            return ChatMessage(role=MessageRole.ASSISTANT, content=content)
+
+
+class AsyncTestAgent(Agent):
+    """Agent for testing async execution."""
+
+    @chat_model_setup
+    @staticmethod
+    def slow_chat_model() -> ResourceDescriptor:  # noqa: D102
+        return ResourceDescriptor(
+            clazz=SlowMockChatModel,
+            connection="placement",
+            tools=["add"],
+        )
+
+    @tool
+    @staticmethod
+    def add(a: int, b: int) -> int:
+        """Calculate the sum of a and b."""
+        time.sleep(5)  # Simulate slow tool execution
+        return a + b
+
+    @action(InputEvent)
+    @staticmethod
+    def process_input(event: InputEvent, ctx: RunnerContext) -> None:  # noqa: D102
+        input = event.input
+        ctx.send_event(
+            ChatRequestEvent(
+                model="slow_chat_model",
+                messages=[
+                    ChatMessage(
+                        role=MessageRole.USER, content=input, extra_args={"task": input}
+                    )
+                ],
+            )
+        )
+
+    @action(ChatResponseEvent)
+    @staticmethod
+    def process_chat_response(event: ChatResponseEvent, ctx: RunnerContext) -> None:  # noqa: D102
+        input = event.response
+        ctx.send_event(OutputEvent(output=input.content))
+
+
+def test_built_in_actions_async_execution() -> None:
+    """Test that built-in actions use async execution correctly.
+
+    This test verifies that chat_model_action and tool_call_action work
+    correctly with async execution, ensuring backward compatibility.
+    """
+    env = StreamExecutionEnvironment.get_execution_environment()
+    env.set_parallelism(1)
+
+    input_stream = env.from_collection(
+        ["calculate the sum of 1 and 2" for _ in range(10)],
+    )
+
+    agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env)
+    output_datastream = (
+        agents_env.from_datastream(
+            input=input_stream, key_selector=lambda x: uuid.uuid4()
+        )
+        .apply(AsyncTestAgent())
+        .to_datastream()
+    )
+
+    output_datastream.print()
+
+    # Measure execution time to verify async doesn't block
+    start_time = time.time()
+    agents_env.execute()
+    execution_time = time.time() - start_time
+
+    assert execution_time < 50
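
As a sanity check on that timing bound: each of the 10 inputs triggers a 5-second mock chat, a 5-second tool call, and a 5-second follow-up chat, so fully serial execution would take on the order of 10 × 15 s = 150 s. The final assert execution_time < 50 can therefore only pass if the blocking calls overlap, which is exactly what the async execution path is meant to provide.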

python/flink_agents/plan/actions/chat_model_action.py

Lines changed: 23 additions & 11 deletions
@@ -28,7 +28,11 @@
 from flink_agents.api.agents.agent import STRUCTURED_OUTPUT
 from flink_agents.api.agents.react_agent import OutputSchema
 from flink_agents.api.chat_message import ChatMessage, MessageRole
-from flink_agents.api.core_options import AgentConfigOptions, ErrorHandlingStrategy
+from flink_agents.api.core_options import (
+    AgentConfigOptions,
+    AgentExecutionOptions,
+    ErrorHandlingStrategy,
+)
 from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent
 from flink_agents.api.events.event import Event
 from flink_agents.api.events.tool_event import ToolRequestEvent, ToolResponseEvent
@@ -80,6 +84,7 @@ def _update_tool_call_context(
     sensory_memory.set(_TOOL_CALL_CONTEXT, tool_call_context)
     return tool_call_context[initial_request_id]

+
 def _save_tool_request_event_context(
     sensory_memory: MemoryObject,
     tool_request_event_id: UUID,
@@ -156,7 +161,7 @@ def _generate_structured_output(
     return response


-def chat(
+async def chat(
     initial_request_id: UUID,
     model: str,
     messages: List[ChatMessage],
@@ -173,16 +178,20 @@ def chat(
         "BaseChatModelSetup", ctx.get_resource(model, ResourceType.CHAT_MODEL)
     )

+    chat_async = ctx.config.get(AgentExecutionOptions.CHAT_ASYNC)
+
     error_handling_strategy = ctx.config.get(AgentConfigOptions.ERROR_HANDLING_STRATEGY)
     num_retries = 0
     if error_handling_strategy == ErrorHandlingStrategy.RETRY:
         num_retries = max(0, ctx.config.get(AgentConfigOptions.MAX_RETRIES))

-    # TODO: support async execution of chat.
     response = None
     for attempt in range(num_retries + 1):
         try:
-            response = chat_model.chat(messages)
+            if chat_async:
+                response = await ctx.execute_async(chat_model.chat, messages)
+            else:
+                response = chat_model.chat(messages)
             if output_schema is not None and len(response.tool_calls) == 0:
                 response = _generate_structured_output(response, output_schema)
             break
@@ -219,9 +228,9 @@ def chat(
     )


-def _process_chat_request(event: ChatRequestEvent, ctx: RunnerContext) -> None:
+async def _process_chat_request(event: ChatRequestEvent, ctx: RunnerContext) -> None:
     """Process chat request event."""
-    chat(
+    await chat(
         initial_request_id=event.id,
         model=event.model,
         messages=event.messages,
@@ -230,7 +239,7 @@ def _process_chat_request(event: ChatRequestEvent, ctx: RunnerContext) -> None:
     )


-def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) -> None:
+async def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) -> None:
     """Organize the tool call context and return it to the LLM."""
     sensory_memory = ctx.sensory_memory
     request_id = event.request_id
@@ -242,6 +251,7 @@ def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) -> None
     initial_request_id = tool_request_event_context["initial_request_id"]

     # update tool call context, and get the entire chat messages.
+    print(f"RequestId: {request_id}, Initial request_id: {initial_request_id}")
     messages = _update_tool_call_context(
         sensory_memory,
         initial_request_id,
@@ -258,7 +268,7 @@ def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) -> None
         ],
     )

-    chat(
+    await chat(
         initial_request_id=initial_request_id,
         model=tool_request_event_context["model"],
         messages=messages,
@@ -267,17 +277,19 @@ def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) -> None
     )


-def process_chat_request_or_tool_response(event: Event, ctx: RunnerContext) -> None:
+async def process_chat_request_or_tool_response(
+    event: Event, ctx: RunnerContext
+) -> None:
     """Built-in action for processing a chat request or tool response.

     This action listens to ChatRequestEvent and ToolResponseEvent, and handles
     the complete chat flow including tool calls. It uses sensory memory to save
     the tool call context, which is a dict mapping request id to chat messages.
     """
     if isinstance(event, ChatRequestEvent):
-        _process_chat_request(event, ctx)
+        await _process_chat_request(event, ctx)
     elif isinstance(event, ToolResponseEvent):
-        _process_tool_response(event, ctx)
+        await _process_tool_response(event, ctx)


 CHAT_MODEL_ACTION = Action(
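
Note that the change threads async through the whole chat path: process_chat_request_or_tool_response is now a coroutine that awaits _process_chat_request and _process_tool_response, and both of those await the shared chat() helper, so the initial ChatRequestEvent and the follow-up chat after a ToolResponseEvent take the same async-aware route, retry loop included.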

python/flink_agents/plan/actions/context_retrieval_action.py

Lines changed: 14 additions & 15 deletions
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #################################################################################
+from flink_agents.api.core_options import AgentExecutionOptions
 from flink_agents.api.events.context_retrieval_event import (
     ContextRetrievalRequestEvent,
     ContextRetrievalResponseEvent,
@@ -27,26 +28,24 @@
 from flink_agents.plan.function import PythonFunction


-def process_context_retrieval_request(event: Event, ctx: RunnerContext) -> None:
+async def process_context_retrieval_request(event: Event, ctx: RunnerContext) -> None:
     """Built-in action for processing context retrieval requests."""
     if isinstance(event, ContextRetrievalRequestEvent):
-        vector_store = ctx.get_resource(
-            event.vector_store,
-            ResourceType.VECTOR_STORE
-        )
+        vector_store = ctx.get_resource(event.vector_store, ResourceType.VECTOR_STORE)

-        query = VectorStoreQuery(
-            query_text=event.query,
-            limit=event.max_results
-        )
+        query = VectorStoreQuery(query_text=event.query, limit=event.max_results)

-        result = vector_store.query(query)
+        rag_async = ctx.config.get(AgentExecutionOptions.RAG_ASYNC)
+        if rag_async:
+            result = await ctx.execute_async(vector_store.query, query)
+        else:
+            result = vector_store.query(query)

-        ctx.send_event(ContextRetrievalResponseEvent(
-            request_id=event.id,
-            query=event.query,
-            documents=result.documents
-        ))
+        ctx.send_event(
+            ContextRetrievalResponseEvent(
+                request_id=event.id, query=event.query, documents=result.documents
+            )
+        )


 CONTEXT_RETRIEVAL_ACTION = Action(

python/flink_agents/plan/actions/tool_call_action.py

Lines changed: 9 additions & 2 deletions
@@ -15,15 +15,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #################################################################################
+
+from flink_agents.api.core_options import AgentExecutionOptions
 from flink_agents.api.events.tool_event import ToolRequestEvent, ToolResponseEvent
 from flink_agents.api.resource import ResourceType
 from flink_agents.api.runner_context import RunnerContext
 from flink_agents.plan.actions.action import Action
 from flink_agents.plan.function import PythonFunction


-def process_tool_request(event: ToolRequestEvent, ctx: RunnerContext) -> None:
+async def process_tool_request(event: ToolRequestEvent, ctx: RunnerContext) -> None:
     """Built-in action for processing tool call requests."""
+    tool_call_async = ctx.config.get(AgentExecutionOptions.TOOL_CALL_ASYNC)
+
     responses = {}
     external_ids = {}
     for tool_call in event.tool_calls:
@@ -35,7 +39,10 @@ def process_tool_request(event: ToolRequestEvent, ctx: RunnerContext) -> None:
         if not tool:
             response = f"Tool `{name}` does not exist."
         else:
-            response = tool.call(**kwargs)
+            if tool_call_async:
+                response = await ctx.execute_async(tool.call, **kwargs)
+            else:
+                response = tool.call(**kwargs)
         responses[id] = response
         external_ids[id] = external_id
     ctx.send_event(
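
One detail visible in the hunk above: the tool calls of a single ToolRequestEvent are still awaited one at a time inside the for loop, so enabling tool-call.async keeps the worker from blocking on each tool.call but does not by itself run the tools of one request in parallel; the gain is that the runtime can, presumably, interleave other work while each call is awaited.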
