@@ -795,6 +795,19 @@ async def mock_create(*args, **kwargs):
795795 assert result .content == "System-aware response"
796796
797797
798+ class AsyncContextManagerMock :
799+ """Helper class for mocking async context managers"""
800+
801+ def __init__ (self , return_value ):
802+ self ._return_value = return_value
803+
804+ async def __aenter__ (self ):
805+ return self ._return_value
806+
807+ async def __aexit__ (self , exc_type , exc_val , exc_tb ):
808+ return None
809+
810+
798811class TestLocalProviderGenerate :
799812 """Test Local provider generate method"""
800813
@@ -814,23 +827,14 @@ async def test_local_provider_generate_basic(self):
814827 "prompt_eval_count" : 100 ,
815828 }
816829
817- async def mock_post (* args , ** kwargs ):
818- mock_resp = MagicMock ()
819-
820- async def mock_json ():
821- return mock_response_data
822-
823- mock_resp .json = mock_json
824- mock_resp .__aenter__ = AsyncMock (return_value = mock_resp )
825- mock_resp .__aexit__ = AsyncMock (return_value = None )
826- return mock_resp
830+ mock_resp = MagicMock ()
831+ mock_resp .json = AsyncMock (return_value = mock_response_data )
827832
828833 mock_session = MagicMock ()
829- mock_session .post = mock_post
830- mock_session .__aenter__ = AsyncMock (return_value = mock_session )
831- mock_session .__aexit__ = AsyncMock (return_value = None )
834+ mock_session .post = MagicMock (return_value = AsyncContextManagerMock (mock_resp ))
832835
833- with patch ("aiohttp.ClientSession" , return_value = mock_session ):
836+ with patch ("aiohttp.ClientSession" ) as mock_client :
837+ mock_client .return_value = AsyncContextManagerMock (mock_session )
834838 result = await provider .generate (messages )
835839
836840 assert result .content == "Local response"
@@ -854,27 +858,21 @@ async def test_local_provider_generate_with_system_prompt(self):
854858 "prompt_eval_count" : 100 ,
855859 }
856860
857- async def mock_post (* args , ** kwargs ):
861+ mock_resp = MagicMock ()
862+ mock_resp .json = AsyncMock (return_value = mock_response_data )
863+
864+ mock_session = MagicMock ()
865+
866+ def mock_post (* args , ** kwargs ):
858867 # Verify system prompt in payload
859868 assert "system" in kwargs ["json" ]
860869 assert kwargs ["json" ]["system" ] == "You are helpful"
870+ return AsyncContextManagerMock (mock_resp )
861871
862- mock_resp = MagicMock ()
863-
864- async def mock_json ():
865- return mock_response_data
866-
867- mock_resp .json = mock_json
868- mock_resp .__aenter__ = AsyncMock (return_value = mock_resp )
869- mock_resp .__aexit__ = AsyncMock (return_value = None )
870- return mock_resp
871-
872- mock_session = MagicMock ()
873872 mock_session .post = mock_post
874- mock_session .__aenter__ = AsyncMock (return_value = mock_session )
875- mock_session .__aexit__ = AsyncMock (return_value = None )
876873
877- with patch ("aiohttp.ClientSession" , return_value = mock_session ):
874+ with patch ("aiohttp.ClientSession" ) as mock_client :
875+ mock_client .return_value = AsyncContextManagerMock (mock_session )
878876 result = await provider .generate (messages , system_prompt = "You are helpful" )
879877
880878 assert result .content == "System-aware local response"
0 commit comments