From 98acbc187f3475f7d3d9a8d6bc09479cc5845993 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 2 Sep 2025 15:23:53 +0000 Subject: [PATCH] Properly close VertexAI HTTP client at end of live test --- tests/test_live.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/test_live.py b/tests/test_live.py index 55ac7227d2..ae15c89ed8 100644 --- a/tests/test_live.py +++ b/tests/test_live.py @@ -34,23 +34,30 @@ def gemini(_: httpx.AsyncClient, _tmp_path: Path) -> Model: return GoogleModel('gemini-1.5-pro') -def vertexai(_: httpx.AsyncClient, tmp_path: Path) -> Model: +def vertexai(http_client: httpx.AsyncClient, tmp_path: Path) -> Model: from google.oauth2 import service_account from pydantic_ai.models.google import GoogleModel from pydantic_ai.providers.google import GoogleProvider - service_account_content = os.environ['GOOGLE_SERVICE_ACCOUNT_CONTENT'] - project_id = json.loads(service_account_content)['project_id'] - service_account_path = tmp_path / 'service_account.json' - service_account_path.write_text(service_account_content) + if service_account_path := os.environ.get('GOOGLE_APPLICATION_CREDENTIALS'): + project_id = json.loads(Path(service_account_path).read_text())['project_id'] + elif service_account_content := os.environ.get('GOOGLE_SERVICE_ACCOUNT_CONTENT'): + project_id = json.loads(service_account_content)['project_id'] + service_account_path = tmp_path / 'service_account.json' + service_account_path.write_text(service_account_content) + else: + pytest.skip( + 'VertexAI live test requires GOOGLE_APPLICATION_CREDENTIALS or GOOGLE_SERVICE_ACCOUNT_CONTENT to be set' + ) credentials = service_account.Credentials.from_service_account_file( # type: ignore[reportUnknownReturnType] service_account_path, scopes=['https://www.googleapis.com/auth/cloud-platform'], ) provider = GoogleProvider(credentials=credentials, project=project_id) - return GoogleModel('gemini-1.5-flash', provider=provider) + provider.client.aio._api_client._async_httpx_client = http_client # type: ignore + return GoogleModel('gemini-2.0-flash', provider=provider) def groq(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model: