Skip to content

Commit 2e2e0a9

Browse files
committed
refactor: remove stream_options from model adapters and consolidate test helpers
Removed stream_options parameter from CrewAI, Google ADK, and PydanticAI model adapters as it was causing issues with non-streaming requests and each framework handles usage information internally. Consolidated test helper functions from separate helpers.py file into conftest.py to improve test organization and maintainability. 移除 CrewAI、Google ADK 和 PydanticAI 模型适配器中的 stream_options 参数,因为它在非流式请求中造成问题,且每个框架都内置处理用量信息。将测试辅助函数从独立的 helpers.py 文件合并到 conftest.py 中,以改善测试组织和可维护性。 Change-Id: I0271c6f8340a52f74b2053a570e3b3c0789933d5
1 parent 8bc2368 commit 2e2e0a9

File tree

17 files changed

+3141
-1094
lines changed

17 files changed

+3141
-1094
lines changed

agentrun/integration/crewai/model_adapter.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@ def wrap_model(self, common_model: Any) -> Any:
1717
from crewai import LLM
1818

1919
info = common_model.get_model_info() # 确保模型可用
20+
21+
# 注意:不在此处设置 stream_options,因为:
22+
# 1. CrewAI 内部决定是否使用流式请求
23+
# 2. 在非流式请求中传递 stream_options 不符合 OpenAI API 规范
24+
# 3. CrewAI 会自行处理 usage 信息
2025
return LLM(
2126
api_key=info.api_key,
2227
model=f"{info.provider or 'openai'}/{info.model}",
2328
base_url=info.base_url,
2429
default_headers=info.headers,
25-
stream_options={"include_usage": True},
2630
# async_client=AsyncClient(headers=info.headers),
2731
)

agentrun/integration/google_adk/model_adapter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,13 @@ def wrap_model(self, common_model: CommonModel) -> Any:
3434

3535
info = common_model.get_model_info()
3636

37+
# 注意:不在此处设置 stream_options,因为:
38+
# 1. Google ADK 内部决定是否使用流式请求
39+
# 2. 在非流式请求中传递 stream_options 不符合 OpenAI API 规范
40+
# 3. Google ADK 会自行处理 usage 信息
3741
return LiteLlm(
3842
model=f"{info.provider or 'openai'}/{info.model}",
3943
api_base=info.base_url,
4044
api_key=info.api_key,
4145
extra_headers=info.headers,
42-
stream_options={"include_usage": True},
4346
)

agentrun/integration/pydantic_ai/model_adapter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ def wrap_model(self, common_model: CommonModel) -> Any:
1717
try:
1818
from pydantic_ai.models.openai import OpenAIChatModel
1919
from pydantic_ai.providers.openai import OpenAIProvider
20-
from pydantic_ai.settings import ModelSettings
2120
except Exception as e:
2221
raise ImportError(
2322
"PydanticAI is not installed. "
@@ -28,16 +27,17 @@ def wrap_model(self, common_model: CommonModel) -> Any:
2827

2928
info = common_model.get_model_info()
3029

30+
# 注意:不在此处设置 stream_options,因为:
31+
# 1. run_sync() 使用非流式请求,不需要 stream_options
32+
# 2. run_stream() 使用流式请求,PydanticAI 会自行处理 usage 信息
33+
# 3. 在非流式请求中传递 stream_options 不符合 OpenAI API 规范
3134
return OpenAIChatModel(
3235
info.model or "",
3336
provider=OpenAIProvider(
3437
base_url=info.base_url,
3538
api_key=info.api_key,
3639
http_client=AsyncClient(headers=info.headers),
3740
),
38-
settings=ModelSettings(
39-
extra_body={"stream_options": {"include_usage": True}}
40-
),
4141
)
4242

4343

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
"""Integration 测试基类和统一响应模型
2+
3+
为所有 integration 测试提供统一的基类,屏蔽框架特定逻辑,
4+
提供类似 AgentServer.invoke 的统一调用接口。
5+
6+
使用方式:
7+
class TestLangChain(IntegrationTestBase):
8+
def create_agent(self, model, tools=None, system_prompt="..."):
9+
...
10+
11+
def invoke(self, agent, message):
12+
...
13+
14+
async def ainvoke(self, agent, message):
15+
...
16+
"""
17+
18+
from abc import ABC, abstractmethod
19+
from dataclasses import dataclass, field
20+
from typing import Any, Dict, Iterator, List, Optional
21+
22+
from agentrun.integration.builtin.model import CommonModel
23+
from agentrun.integration.utils.tool import CommonToolSet
24+
25+
26+
@dataclass
27+
class ToolCallInfo:
28+
"""工具调用信息"""
29+
30+
name: str
31+
arguments: Dict[str, Any]
32+
id: str
33+
result: Optional[str] = None
34+
35+
36+
@dataclass
37+
class IntegrationTestResult:
38+
"""统一的 Integration 测试结果
39+
40+
将不同框架的响应格式统一为标准格式,便于测试验证。
41+
"""
42+
43+
final_text: str
44+
"""最终文本响应"""
45+
46+
tool_calls: List[ToolCallInfo] = field(default_factory=list)
47+
"""所有工具调用信息"""
48+
49+
messages: List[Dict[str, Any]] = field(default_factory=list)
50+
"""完整消息历史(框架特定格式)"""
51+
52+
raw_response: Any = None
53+
"""原始框架响应"""
54+
55+
def has_tool_calls(self) -> bool:
56+
"""是否有工具调用"""
57+
return len(self.tool_calls) > 0
58+
59+
def get_tool_call(self, name: str) -> Optional[ToolCallInfo]:
60+
"""获取指定名称的工具调用"""
61+
for tc in self.tool_calls:
62+
if tc.name == name:
63+
return tc
64+
return None
65+
66+
67+
@dataclass
68+
class StreamChunk:
69+
"""流式输出的单个块"""
70+
71+
content: Optional[str] = None
72+
"""文本内容"""
73+
74+
tool_call_id: Optional[str] = None
75+
"""工具调用 ID"""
76+
77+
tool_call_name: Optional[str] = None
78+
"""工具调用名称"""
79+
80+
tool_call_args_delta: Optional[str] = None
81+
"""工具调用参数增量"""
82+
83+
is_final: bool = False
84+
"""是否是最后一个块"""
85+
86+
87+
class IntegrationTestBase(ABC):
88+
"""Integration 测试基类
89+
90+
每个框架的测试类需要继承此基类并实现以下抽象方法:
91+
- create_agent(): 创建框架特定的 Agent
92+
- invoke(): 同步调用 Agent
93+
- ainvoke(): 异步调用 Agent
94+
- stream(): 流式调用 Agent(可选)
95+
96+
基类提供统一的测试方法和验证逻辑。
97+
"""
98+
99+
@abstractmethod
100+
def create_agent(
101+
self,
102+
model: CommonModel,
103+
tools: Optional[CommonToolSet] = None,
104+
system_prompt: str = "You are a helpful assistant.",
105+
) -> Any:
106+
"""创建框架特定的 Agent
107+
108+
Args:
109+
model: AgentRun 通用模型
110+
tools: 可选的工具集
111+
system_prompt: 系统提示词
112+
113+
Returns:
114+
框架特定的 Agent 对象
115+
"""
116+
pass
117+
118+
@abstractmethod
119+
def invoke(self, agent: Any, message: str) -> IntegrationTestResult:
120+
"""同步调用 Agent
121+
122+
Args:
123+
agent: 框架特定的 Agent 对象
124+
message: 用户消息
125+
126+
Returns:
127+
统一的测试结果
128+
"""
129+
pass
130+
131+
@abstractmethod
132+
async def ainvoke(self, agent: Any, message: str) -> IntegrationTestResult:
133+
"""异步调用 Agent
134+
135+
Args:
136+
agent: 框架特定的 Agent 对象
137+
message: 用户消息
138+
139+
Returns:
140+
统一的测试结果
141+
"""
142+
pass
143+
144+
def stream(self, agent: Any, message: str) -> Iterator[StreamChunk]:
145+
"""流式调用 Agent(可选实现)
146+
147+
Args:
148+
agent: 框架特定的 Agent 对象
149+
message: 用户消息
150+
151+
Yields:
152+
流式输出块
153+
154+
Raises:
155+
NotImplementedError: 如果框架不支持流式调用
156+
"""
157+
raise NotImplementedError(
158+
f"{self.__class__.__name__} does not support streaming"
159+
)
160+
161+
async def astream(self, agent: Any, message: str) -> Iterator[StreamChunk]:
162+
"""异步流式调用 Agent(可选实现)
163+
164+
Args:
165+
agent: 框架特定的 Agent 对象
166+
message: 用户消息
167+
168+
Yields:
169+
流式输出块
170+
171+
Raises:
172+
NotImplementedError: 如果框架不支持流式调用
173+
"""
174+
raise NotImplementedError(
175+
f"{self.__class__.__name__} does not support async streaming"
176+
)
177+
178+
# =========================================================================
179+
# 验证辅助方法
180+
# =========================================================================
181+
182+
def assert_final_text(self, result: IntegrationTestResult, expected: str):
183+
"""验证最终文本"""
184+
assert (
185+
result.final_text == expected
186+
), f"Expected '{expected}', got '{result.final_text}'"
187+
188+
def assert_final_text_contains(
189+
self, result: IntegrationTestResult, substring: str
190+
):
191+
"""验证最终文本包含指定字符串"""
192+
assert (
193+
substring in result.final_text
194+
), f"Expected '{substring}' in '{result.final_text}'"
195+
196+
def assert_tool_called(
197+
self,
198+
result: IntegrationTestResult,
199+
tool_name: str,
200+
expected_args: Optional[Dict[str, Any]] = None,
201+
):
202+
"""验证工具被调用"""
203+
tool_call = result.get_tool_call(tool_name)
204+
assert tool_call is not None, (
205+
f"Tool '{tool_name}' was not called. Called tools:"
206+
f" {[tc.name for tc in result.tool_calls]}"
207+
)
208+
209+
if expected_args is not None:
210+
assert (
211+
tool_call.arguments == expected_args
212+
), f"Expected args {expected_args}, got {tool_call.arguments}"
213+
214+
def assert_tool_not_called(
215+
self, result: IntegrationTestResult, tool_name: str
216+
):
217+
"""验证工具未被调用"""
218+
tool_call = result.get_tool_call(tool_name)
219+
assert tool_call is None, f"Tool '{tool_name}' was unexpectedly called"
220+
221+
def assert_no_tool_calls(self, result: IntegrationTestResult):
222+
"""验证没有工具调用"""
223+
assert not result.has_tool_calls(), (
224+
"Expected no tool calls, "
225+
f"got {[tc.name for tc in result.tool_calls]}"
226+
)
227+
228+
def assert_tool_call_count(
229+
self, result: IntegrationTestResult, expected_count: int
230+
):
231+
"""验证工具调用次数"""
232+
actual_count = len(result.tool_calls)
233+
assert actual_count == expected_count, (
234+
f"Expected {expected_count} tool calls, got {actual_count}. "
235+
f"Tools: {[tc.name for tc in result.tool_calls]}"
236+
)

0 commit comments

Comments
 (0)