Skip to content

Commit 004cb46

Browse files
wenjin272xintongsong
authored andcommitted
[plan][python] Built-in actions support async execution.
Co-authored-by: Shekharrajak <shekharrajak@live.com> fix fix
1 parent 9474216 commit 004cb46

File tree

5 files changed

+221
-28
lines changed

5 files changed

+221
-28
lines changed

python/flink_agents/api/core_options.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,25 @@ class AgentConfigOptions(metaclass=AgentConfigOptionsMeta):
101101
config_type=int,
102102
default=3,
103103
)
104+
105+
106+
class AgentExecutionOptions(metaclass=AgentConfigOptionsMeta):
107+
"""Execution options for Flink Agents."""
108+
109+
CHAT_ASYNC = ConfigOption(
110+
key="chat.async",
111+
config_type=bool,
112+
default=True,
113+
)
114+
115+
TOOL_CALL_ASYNC = ConfigOption(
116+
key="tool-call.async",
117+
config_type=bool,
118+
default=True,
119+
)
120+
121+
RAG_ASYNC = ConfigOption(
122+
key="rag.async",
123+
config_type=bool,
124+
default=True,
125+
)
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
################################################################################
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#################################################################################
18+
import time
19+
import uuid
20+
from typing import Any, Dict, Sequence
21+
22+
from pyflink.datastream import StreamExecutionEnvironment
23+
from typing_extensions import override
24+
25+
from flink_agents.api.agents.agent import Agent
26+
from flink_agents.api.chat_message import ChatMessage, MessageRole
27+
from flink_agents.api.chat_models.chat_model import BaseChatModelSetup
28+
from flink_agents.api.decorators import action, chat_model_setup, tool
29+
from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent
30+
from flink_agents.api.events.event import InputEvent, OutputEvent
31+
from flink_agents.api.execution_environment import AgentsExecutionEnvironment
32+
from flink_agents.api.resource import ResourceDescriptor
33+
from flink_agents.api.runner_context import RunnerContext
34+
from flink_agents.api.tools.tool import ToolType
35+
36+
37+
class SlowMockChatModel(BaseChatModelSetup):
38+
"""Mock ChatModel with slow connection."""
39+
40+
@property
41+
def model_kwargs(self) -> Dict[str, Any]: # noqa: D102
42+
return {}
43+
44+
@override
45+
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage:
46+
time.sleep(5) # Simulate network delay
47+
if "sum" in messages[-1].content:
48+
input = messages[-1].content
49+
function = {"name": "add", "arguments": {"a": 1, "b": 2}}
50+
tool_call = {
51+
"id": uuid.uuid4(),
52+
"type": ToolType.FUNCTION,
53+
"function": function,
54+
}
55+
return ChatMessage(
56+
role=MessageRole.ASSISTANT, content=input, tool_calls=[tool_call]
57+
)
58+
else:
59+
content = "\n".join([message.content for message in messages])
60+
return ChatMessage(role=MessageRole.ASSISTANT, content=content)
61+
62+
63+
class AsyncTestAgent(Agent):
64+
"""Agent for testing async execution."""
65+
66+
@chat_model_setup
67+
@staticmethod
68+
def slow_chat_model() -> ResourceDescriptor: # noqa: D102
69+
return ResourceDescriptor(
70+
clazz=f"{SlowMockChatModel.__module__}.{SlowMockChatModel.__name__}",
71+
connection="placement",
72+
tools=["add"],
73+
)
74+
75+
@tool
76+
@staticmethod
77+
def add(a: int, b: int) -> int:
78+
"""Calculate the sum of a and b."""
79+
time.sleep(5) # Simulate slow tool execution
80+
return a + b
81+
82+
@action(InputEvent)
83+
@staticmethod
84+
def process_input(event: InputEvent, ctx: RunnerContext) -> None: # noqa: D102
85+
input = event.input
86+
ctx.send_event(
87+
ChatRequestEvent(
88+
model="slow_chat_model",
89+
messages=[
90+
ChatMessage(
91+
role=MessageRole.USER, content=input, extra_args={"task": input}
92+
)
93+
],
94+
)
95+
)
96+
97+
@action(ChatResponseEvent)
98+
@staticmethod
99+
def process_chat_response(event: ChatResponseEvent, ctx: RunnerContext) -> None: # noqa: D102
100+
input = event.response
101+
ctx.send_event(OutputEvent(output=input.content))
102+
103+
104+
def test_built_in_actions_async_execution() -> None:
105+
"""Test that built-in actions use async execution correctly.
106+
107+
This test verifies that chat_model_action and tool_call_action work
108+
correctly with async execution, ensuring backward compatibility.
109+
"""
110+
env = StreamExecutionEnvironment.get_execution_environment()
111+
env.set_parallelism(1)
112+
113+
input_stream = env.from_collection(
114+
["calculate the sum of 1 and 2" for _ in range(10)],
115+
)
116+
117+
agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env)
118+
output_datastream = (
119+
agents_env.from_datastream(
120+
input=input_stream, key_selector=lambda x: uuid.uuid4()
121+
)
122+
.apply(AsyncTestAgent())
123+
.to_datastream()
124+
)
125+
126+
output_datastream.print()
127+
128+
# Measure execution time to verify async doesn't block
129+
start_time = time.time()
130+
agents_env.execute()
131+
execution_time = time.time() - start_time
132+
133+
assert execution_time < 50

python/flink_agents/plan/actions/chat_model_action.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@
2828
from flink_agents.api.agents.agent import STRUCTURED_OUTPUT
2929
from flink_agents.api.agents.react_agent import OutputSchema
3030
from flink_agents.api.chat_message import ChatMessage, MessageRole
31-
from flink_agents.api.core_options import AgentConfigOptions, ErrorHandlingStrategy
31+
from flink_agents.api.chat_models.java_chat_model import JavaChatModelSetup
32+
from flink_agents.api.core_options import (
33+
AgentConfigOptions,
34+
AgentExecutionOptions,
35+
ErrorHandlingStrategy,
36+
)
3237
from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent
3338
from flink_agents.api.events.event import Event
3439
from flink_agents.api.events.tool_event import ToolRequestEvent, ToolResponseEvent
@@ -80,6 +85,7 @@ def _update_tool_call_context(
8085
sensory_memory.set(_TOOL_CALL_CONTEXT, tool_call_context)
8186
return tool_call_context[initial_request_id]
8287

88+
8389
def _save_tool_request_event_context(
8490
sensory_memory: MemoryObject,
8591
tool_request_event_id: UUID,
@@ -156,7 +162,7 @@ def _generate_structured_output(
156162
return response
157163

158164

159-
def chat(
165+
async def chat(
160166
initial_request_id: UUID,
161167
model: str,
162168
messages: List[ChatMessage],
@@ -173,16 +179,23 @@ def chat(
173179
"BaseChatModelSetup", ctx.get_resource(model, ResourceType.CHAT_MODEL)
174180
)
175181

182+
chat_async = ctx.config.get(AgentExecutionOptions.CHAT_ASYNC)
183+
# java chat model doesn't support async execution.
184+
chat_async = chat_async and not isinstance(chat_model, JavaChatModelSetup)
185+
176186
error_handling_strategy = ctx.config.get(AgentConfigOptions.ERROR_HANDLING_STRATEGY)
177187
num_retries = 0
178188
if error_handling_strategy == ErrorHandlingStrategy.RETRY:
179189
num_retries = max(0, ctx.config.get(AgentConfigOptions.MAX_RETRIES))
180190

181-
# TODO: support async execution of chat.
182191
response = None
183192
for attempt in range(num_retries + 1):
184193
try:
185-
response = chat_model.chat(messages)
194+
if chat_async:
195+
response = await ctx.durable_execute_async(chat_model.chat, messages)
196+
else:
197+
response = chat_model.chat(messages)
198+
186199
if response.extra_args.get("model_name") and response.extra_args.get("promptTokens") and response.extra_args.get("completionTokens"):
187200
chat_model._record_token_metrics(response.extra_args["model_name"], response.extra_args["promptTokens"], response.extra_args["completionTokens"])
188201
if output_schema is not None and len(response.tool_calls) == 0:
@@ -221,9 +234,9 @@ def chat(
221234
)
222235

223236

224-
def _process_chat_request(event: ChatRequestEvent, ctx: RunnerContext) -> None:
237+
async def _process_chat_request(event: ChatRequestEvent, ctx: RunnerContext) -> None:
225238
"""Process chat request event."""
226-
chat(
239+
await chat(
227240
initial_request_id=event.id,
228241
model=event.model,
229242
messages=event.messages,
@@ -232,7 +245,7 @@ def _process_chat_request(event: ChatRequestEvent, ctx: RunnerContext) -> None:
232245
)
233246

234247

235-
def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) -> None:
248+
async def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) -> None:
236249
"""Organize the tool call context and return it to the LLM."""
237250
sensory_memory = ctx.sensory_memory
238251
request_id = event.request_id
@@ -260,7 +273,7 @@ def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) -> None
260273
],
261274
)
262275

263-
chat(
276+
await chat(
264277
initial_request_id=initial_request_id,
265278
model=tool_request_event_context["model"],
266279
messages=messages,
@@ -269,17 +282,21 @@ def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) -> None
269282
)
270283

271284

272-
def process_chat_request_or_tool_response(event: Event, ctx: RunnerContext) -> None:
285+
async def process_chat_request_or_tool_response(
286+
event: Event, ctx: RunnerContext
287+
) -> None:
273288
"""Built-in action for processing a chat request or tool response.
274289
275290
This action listens to ChatRequestEvent and ToolResponseEvent, and handles
276291
the complete chat flow including tool calls. It uses sensory memory to save
277292
the tool call context, which is a dict mapping request id to chat messages.
278293
"""
294+
# To avoid https://github.com/alibaba/pemja/issues/88, we log a message here.
295+
logging.debug("Processing chat request asynchronously.")
279296
if isinstance(event, ChatRequestEvent):
280-
_process_chat_request(event, ctx)
297+
await _process_chat_request(event, ctx)
281298
elif isinstance(event, ToolResponseEvent):
282-
_process_tool_response(event, ctx)
299+
await _process_tool_response(event, ctx)
283300

284301

285302
CHAT_MODEL_ACTION = Action(

python/flink_agents/plan/actions/context_retrieval_action.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,38 +15,46 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717
#################################################################################
18+
import logging
19+
20+
from flink_agents.api.core_options import AgentExecutionOptions
1821
from flink_agents.api.events.context_retrieval_event import (
1922
ContextRetrievalRequestEvent,
2023
ContextRetrievalResponseEvent,
2124
)
2225
from flink_agents.api.events.event import Event
2326
from flink_agents.api.resource import ResourceType
2427
from flink_agents.api.runner_context import RunnerContext
28+
from flink_agents.api.vector_stores.java_vector_store import JavaVectorStore
2529
from flink_agents.api.vector_stores.vector_store import VectorStoreQuery
2630
from flink_agents.plan.actions.action import Action
2731
from flink_agents.plan.function import PythonFunction
2832

33+
_logger = logging.getLogger(__name__)
2934

30-
def process_context_retrieval_request(event: Event, ctx: RunnerContext) -> None:
35+
async def process_context_retrieval_request(event: Event, ctx: RunnerContext) -> None:
3136
"""Built-in action for processing context retrieval requests."""
3237
if isinstance(event, ContextRetrievalRequestEvent):
33-
vector_store = ctx.get_resource(
34-
event.vector_store,
35-
ResourceType.VECTOR_STORE
36-
)
38+
vector_store = ctx.get_resource(event.vector_store, ResourceType.VECTOR_STORE)
3739

38-
query = VectorStoreQuery(
39-
query_text=event.query,
40-
limit=event.max_results
41-
)
40+
query = VectorStoreQuery(query_text=event.query, limit=event.max_results)
4241

43-
result = vector_store.query(query)
42+
rag_async = ctx.config.get(AgentExecutionOptions.RAG_ASYNC)
43+
# java vector store doesn't support async execution.
44+
rag_async = rag_async and not isinstance(vector_store, JavaVectorStore)
45+
if rag_async:
46+
# To avoid https://github.com/alibaba/pemja/issues/88,
47+
# we log a message here.
48+
_logger.debug("Processing context retrieval asynchronously.")
49+
result = await ctx.durable_execute_async(vector_store.query, query)
50+
else:
51+
result = vector_store.query(query)
4452

45-
ctx.send_event(ContextRetrievalResponseEvent(
46-
request_id=event.id,
47-
query=event.query,
48-
documents=result.documents
49-
))
53+
ctx.send_event(
54+
ContextRetrievalResponseEvent(
55+
request_id=event.id, query=event.query, documents=result.documents
56+
)
57+
)
5058

5159

5260
CONTEXT_RETRIEVAL_ACTION = Action(

python/flink_agents/plan/actions/tool_call_action.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,25 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717
#################################################################################
18+
import logging
19+
20+
from flink_agents.api.core_options import AgentExecutionOptions
1821
from flink_agents.api.events.tool_event import ToolRequestEvent, ToolResponseEvent
1922
from flink_agents.api.resource import ResourceType
2023
from flink_agents.api.runner_context import RunnerContext
2124
from flink_agents.plan.actions.action import Action
2225
from flink_agents.plan.function import PythonFunction
2326

27+
_logger = logging.getLogger(__name__)
2428

25-
def process_tool_request(event: ToolRequestEvent, ctx: RunnerContext) -> None:
29+
async def process_tool_request(event: ToolRequestEvent, ctx: RunnerContext) -> None:
2630
"""Built-in action for processing tool call requests."""
31+
tool_call_async = ctx.config.get(AgentExecutionOptions.TOOL_CALL_ASYNC)
32+
33+
if tool_call_async:
34+
# To avoid https://github.com/alibaba/pemja/issues/88, we log a message here.
35+
_logger.debug("Processing tool call asynchronously.")
36+
2737
responses = {}
2838
external_ids = {}
2939
for tool_call in event.tool_calls:
@@ -35,7 +45,10 @@ def process_tool_request(event: ToolRequestEvent, ctx: RunnerContext) -> None:
3545
if not tool:
3646
response = f"Tool `{name}` does not exist."
3747
else:
38-
response = tool.call(**kwargs)
48+
if tool_call_async:
49+
response = await ctx.durable_execute_async(tool.call, **kwargs)
50+
else:
51+
response = tool.call(**kwargs)
3952
responses[id] = response
4053
external_ids[id] = external_id
4154
ctx.send_event(

0 commit comments

Comments
 (0)