1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- import asyncio
1615import unittest
1716import unittest .mock
1817
19- from google .genai .models import Models , AsyncModels
18+ from google .genai .models import AsyncModels , Models
19+
2020from ..common .base import TestCase as CommonTestCaseBase
2121from .util import convert_to_response , create_response
2222
@@ -45,7 +45,7 @@ def mock_generate_content(self):
4545 if self ._generate_content_mock is None :
4646 self ._create_and_install_mocks ()
4747 return self ._generate_content_mock
48-
48+
4949 @property
5050 def mock_generate_content_stream (self ):
5151 if self ._generate_content_stream_mock is None :
@@ -68,66 +68,88 @@ def _create_and_install_mocks(self):
6868
6969 def _create_nonstream_mock (self ):
7070 mock = unittest .mock .MagicMock ()
71+
7172 def _default_impl (* args , ** kwargs ):
7273 if not self ._responses :
7374 return create_response (text = "Some response" )
7475 index = self ._response_index % len (self ._responses )
7576 result = self ._responses [index ]
7677 self ._response_index += 1
7778 return result
79+
7880 mock .side_effect = _default_impl
7981 return mock
8082
8183 def _create_stream_mock (self ):
8284 mock = unittest .mock .MagicMock ()
85+
8386 def _default_impl (* args , ** kwargs ):
8487 for response in self ._responses :
8588 yield response
89+
8690 mock .side_effect = _default_impl
8791 return mock
8892
8993 def _install_mocks (self ):
9094 output_wrapped = self ._wrap_output (self ._generate_content_mock )
91- output_wrapped_stream = self ._wrap_output_stream (self ._generate_content_stream_mock )
95+ output_wrapped_stream = self ._wrap_output_stream (
96+ self ._generate_content_stream_mock
97+ )
9298 Models .generate_content = output_wrapped
9399 Models .generate_content_stream = output_wrapped_stream
94100 AsyncModels .generate_content = self ._async_wrapper (output_wrapped )
95- AsyncModels .generate_content_stream = self ._async_stream_wrapper (output_wrapped_stream )
96-
101+ AsyncModels .generate_content_stream = self ._async_stream_wrapper (
102+ output_wrapped_stream
103+ )
104+
97105 def _wrap_output (self , mock_generate_content ):
98106 def _wrapped (* args , ** kwargs ):
99107 return convert_to_response (mock_generate_content (* args , ** kwargs ))
108+
100109 return _wrapped
101110
102111 def _wrap_output_stream (self , mock_generate_content_stream ):
103112 def _wrapped (* args , ** kwargs ):
104113 for output in mock_generate_content_stream (* args , ** kwargs ):
105- yield convert_to_response (output )
114+ yield convert_to_response (output )
115+
106116 return _wrapped
107117
108118 def _async_wrapper (self , mock_generate_content ):
109119 async def _wrapped (* args , ** kwargs ):
110120 return mock_generate_content (* args , ** kwargs )
121+
111122 return _wrapped
112123
113124 def _async_stream_wrapper (self , mock_generate_content_stream ):
114125 async def _wrapped (* args , ** kwargs ):
115126 async def _internal_generator ():
116127 for result in mock_generate_content_stream (* args , ** kwargs ):
117128 yield result
129+
118130 return _internal_generator ()
131+
119132 return _wrapped
120133
121134 def tearDown (self ):
122135 super ().tearDown ()
123136 if self ._generate_content_mock is None :
124137 assert Models .generate_content == self ._original_generate_content
125- assert Models .generate_content_stream == self ._original_generate_content_stream
126- assert AsyncModels .generate_content == self ._original_async_generate_content
127- assert AsyncModels .generate_content_stream == self ._original_async_generate_content_stream
138+ assert (
139+ Models .generate_content_stream
140+ == self ._original_generate_content_stream
141+ )
142+ assert (
143+ AsyncModels .generate_content
144+ == self ._original_async_generate_content
145+ )
146+ assert (
147+ AsyncModels .generate_content_stream
148+ == self ._original_async_generate_content_stream
149+ )
128150 Models .generate_content = self ._original_generate_content
129151 Models .generate_content_stream = self ._original_generate_content_stream
130152 AsyncModels .generate_content = self ._original_async_generate_content
131153 AsyncModels .generate_content_stream = (
132154 self ._original_async_generate_content_stream
133- )
155+ )
0 commit comments