Skip to content

Commit 4ddc5bf

Browse files
authored
Merge branch 'main' into alex/operation.cost
2 parents 3f2b5c5 + ca079f5 commit 4ddc5bf

File tree

3 files changed

+17
-13
lines changed

3 files changed

+17
-13
lines changed

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import anyio
1313
from opentelemetry.trace import NoOpTracer, use_span
1414
from pydantic.json_schema import GenerateJsonSchema
15-
from typing_extensions import TypeVar, deprecated
15+
from typing_extensions import Self, TypeVar, deprecated
1616

1717
from pydantic_graph import Graph
1818

@@ -1355,7 +1355,7 @@ def _prepare_output_schema(
13551355

13561356
return schema # pyright: ignore[reportReturnType]
13571357

1358-
async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]:
1358+
async def __aenter__(self) -> Self:
13591359
"""Enter the agent context.
13601360
13611361
This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered as `toolsets` so they are ready to be used.

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from pydantic_ai.tools import RunContext, ToolDefinition
2222

23+
from .direct import model_request
2324
from .toolsets.abstract import AbstractToolset, ToolsetTool
2425

2526
try:
@@ -329,11 +330,7 @@ async def _sampling_callback(
329330
if stop_sequences := params.stopSequences: # pragma: no branch
330331
model_settings['stop_sequences'] = stop_sequences
331332

332-
model_response = await self.sampling_model.request(
333-
pai_messages,
334-
model_settings,
335-
models.ModelRequestParameters(),
336-
)
333+
model_response = await model_request(self.sampling_model, pai_messages, model_settings=model_settings)
337334
return mcp_types.CreateMessageResult(
338335
role='assistant',
339336
content=_mcp.map_from_model_response(model_response),

tests/test_live.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,30 @@ def gemini(_: httpx.AsyncClient, _tmp_path: Path) -> Model:
3434
return GoogleModel('gemini-1.5-pro')
3535

3636

37-
def vertexai(_: httpx.AsyncClient, tmp_path: Path) -> Model:
37+
def vertexai(http_client: httpx.AsyncClient, tmp_path: Path) -> Model:
3838
from google.oauth2 import service_account
3939

4040
from pydantic_ai.models.google import GoogleModel
4141
from pydantic_ai.providers.google import GoogleProvider
4242

43-
service_account_content = os.environ['GOOGLE_SERVICE_ACCOUNT_CONTENT']
44-
project_id = json.loads(service_account_content)['project_id']
45-
service_account_path = tmp_path / 'service_account.json'
46-
service_account_path.write_text(service_account_content)
43+
if service_account_path := os.environ.get('GOOGLE_APPLICATION_CREDENTIALS'):
44+
project_id = json.loads(Path(service_account_path).read_text())['project_id']
45+
elif service_account_content := os.environ.get('GOOGLE_SERVICE_ACCOUNT_CONTENT'):
46+
project_id = json.loads(service_account_content)['project_id']
47+
service_account_path = tmp_path / 'service_account.json'
48+
service_account_path.write_text(service_account_content)
49+
else:
50+
pytest.skip(
51+
'VertexAI live test requires GOOGLE_APPLICATION_CREDENTIALS or GOOGLE_SERVICE_ACCOUNT_CONTENT to be set'
52+
)
4753

4854
credentials = service_account.Credentials.from_service_account_file( # type: ignore[reportUnknownReturnType]
4955
service_account_path,
5056
scopes=['https://www.googleapis.com/auth/cloud-platform'],
5157
)
5258
provider = GoogleProvider(credentials=credentials, project=project_id)
53-
return GoogleModel('gemini-1.5-flash', provider=provider)
59+
provider.client.aio._api_client._async_httpx_client = http_client # type: ignore
60+
return GoogleModel('gemini-2.0-flash', provider=provider)
5461

5562

5663
def groq(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:

0 commit comments

Comments
 (0)