Skip to content

Commit 3d08e92

Browse files
GeneAIclaude
authored andcommitted
fix: Fix LocalProvider async context manager mocking in tests
Added AsyncContextManagerMock helper class to properly mock aiohttp.ClientSession's async with protocol. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent be2cbf1 commit 3d08e92

File tree

1 file changed

+27
-29
lines changed

1 file changed

+27
-29
lines changed

tests/test_providers.py

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,19 @@ async def mock_create(*args, **kwargs):
795795
assert result.content == "System-aware response"
796796

797797

798+
class AsyncContextManagerMock:
799+
"""Helper class for mocking async context managers"""
800+
801+
def __init__(self, return_value):
802+
self._return_value = return_value
803+
804+
async def __aenter__(self):
805+
return self._return_value
806+
807+
async def __aexit__(self, exc_type, exc_val, exc_tb):
808+
return None
809+
810+
798811
class TestLocalProviderGenerate:
799812
"""Test Local provider generate method"""
800813

@@ -814,23 +827,14 @@ async def test_local_provider_generate_basic(self):
814827
"prompt_eval_count": 100,
815828
}
816829

817-
async def mock_post(*args, **kwargs):
818-
mock_resp = MagicMock()
819-
820-
async def mock_json():
821-
return mock_response_data
822-
823-
mock_resp.json = mock_json
824-
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
825-
mock_resp.__aexit__ = AsyncMock(return_value=None)
826-
return mock_resp
830+
mock_resp = MagicMock()
831+
mock_resp.json = AsyncMock(return_value=mock_response_data)
827832

828833
mock_session = MagicMock()
829-
mock_session.post = mock_post
830-
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
831-
mock_session.__aexit__ = AsyncMock(return_value=None)
834+
mock_session.post = MagicMock(return_value=AsyncContextManagerMock(mock_resp))
832835

833-
with patch("aiohttp.ClientSession", return_value=mock_session):
836+
with patch("aiohttp.ClientSession") as mock_client:
837+
mock_client.return_value = AsyncContextManagerMock(mock_session)
834838
result = await provider.generate(messages)
835839

836840
assert result.content == "Local response"
@@ -854,27 +858,21 @@ async def test_local_provider_generate_with_system_prompt(self):
854858
"prompt_eval_count": 100,
855859
}
856860

857-
async def mock_post(*args, **kwargs):
861+
mock_resp = MagicMock()
862+
mock_resp.json = AsyncMock(return_value=mock_response_data)
863+
864+
mock_session = MagicMock()
865+
866+
def mock_post(*args, **kwargs):
858867
# Verify system prompt in payload
859868
assert "system" in kwargs["json"]
860869
assert kwargs["json"]["system"] == "You are helpful"
870+
return AsyncContextManagerMock(mock_resp)
861871

862-
mock_resp = MagicMock()
863-
864-
async def mock_json():
865-
return mock_response_data
866-
867-
mock_resp.json = mock_json
868-
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
869-
mock_resp.__aexit__ = AsyncMock(return_value=None)
870-
return mock_resp
871-
872-
mock_session = MagicMock()
873872
mock_session.post = mock_post
874-
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
875-
mock_session.__aexit__ = AsyncMock(return_value=None)
876873

877-
with patch("aiohttp.ClientSession", return_value=mock_session):
874+
with patch("aiohttp.ClientSession") as mock_client:
875+
mock_client.return_value = AsyncContextManagerMock(mock_session)
878876
result = await provider.generate(messages, system_prompt="You are helpful")
879877

880878
assert result.content == "System-aware local response"

0 commit comments

Comments
 (0)