1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from typing import Optional
16+
1517import os
1618import unittest
19+ import unittest .mock
1720
1821import google .genai
22+ import google .genai .types as genai_types
23+ from google .genai .models import Models , AsyncModels
1924
2025from .instrumentation_context import InstrumentationContext
2126from .otel_mocker import OTelMocker
@@ -28,6 +33,7 @@ def refresh(self, request):
2833
2934
3035class TestCase (unittest .TestCase ):
36+
3137 def setUp (self ):
3238 self ._otel = OTelMocker ()
3339 self ._otel .install ()
@@ -40,11 +46,31 @@ def setUp(self):
4046 self ._client = None
4147 self ._uses_vertex = False
4248 self ._credentials = _FakeCredentials ()
49+ self ._generate_content_mock = None
50+ self ._generate_content_stream_mock = None
51+ self ._original_generate_content = Models .generate_content
52+ self ._original_generate_content_stream = Models .generate_content_stream
53+ self ._original_async_generate_content = AsyncModels .generate_content
54+ self ._original_async_generate_content_stream = (
55+ AsyncModels .generate_content_stream
56+ )
4357
4458 def _lazy_init (self ):
4559 self ._instrumentation_context = InstrumentationContext ()
4660 self ._instrumentation_context .install ()
4761
62+ @property
63+ def mock_generate_content (self ):
64+ if self ._generate_content_mock is None :
65+ self ._create_mocks ()
66+ return self ._generate_content_mock
67+
68+ @property
69+ def mock_generate_content_stream (self ):
70+ if self ._generate_content_stream_mock is None :
71+ self ._create_mocks ()
72+ return self ._generate_content_stream_mock
73+
4874 @property
4975 def client (self ):
5076 if self ._client is None :
@@ -62,6 +88,81 @@ def otel(self):
6288 def set_use_vertex (self , use_vertex ):
6389 self ._uses_vertex = use_vertex
6490
91+ def generate_content_response (
92+ self ,
93+ part : Optional [genai_types .Part ] = None ,
94+ parts : Optional [list [genai_types .Part ]] = None ,
95+ content : Optional [genai_types .Content ] = None ,
96+ candidate : Optional [genai_types .Candidate ] = None ,
97+ candidates : Optional [list [genai_types .Candidate ]] = None ,
98+ text : Optional [str ] = None ):
99+ if text is None :
100+ text = 'Some response text'
101+ if part is None :
102+ part = genai_types .Part (text = text )
103+ if parts is None :
104+ parts = [part ]
105+ if content is None :
106+ content = genai_types .Content (parts = parts , role = 'model' )
107+ if candidate is None :
108+ candidate = genai_types .Candidate (content = content )
109+ if candidates is None :
110+ candidates = [candidate ]
111+ return genai_types .GenerateContentResponse (candidates = candidates )
112+
113+ def _create_mocks (self ):
114+ print ("Initializing mocks." )
115+ if self ._client is not None :
116+ self ._client = None
117+ if self ._instrumentation_context is not None :
118+ self ._instrumentation_context .uninstall ()
119+ self ._instrumentation_context = None
120+ self ._generate_content_mock = unittest .mock .MagicMock ()
121+ self ._generate_content_stream_mock = unittest .mock .MagicMock ()
122+
123+ def convert_response (arg ):
124+ if isinstance (arg , genai_types .GenerateContentResponse ):
125+ return arg
126+ if isinstance (arg , str ):
127+ return self .generate_content_response (text = arg )
128+ if isinstance (arg , dict ):
129+ try :
130+ return genai_types .GenerateContentResponse (** arg )
131+ except Exception :
132+ return self .generate_content_response (** arg )
133+ return arg
134+
135+ def default_stream (* args , ** kwargs ):
136+ result = self ._generate_content_mock (* args , ** kwargs )
137+ yield result
138+ self ._generate_content_stream_mock .side_effect = default_stream
139+
140+ def sync_variant (* args , ** kwargs ):
141+ return convert_response (self ._generate_content_mock (* args , ** kwargs ))
142+
143+ def sync_stream_variant (* args , ** kwargs ):
144+ print ("Calling sync stream variant" )
145+ for result in self ._generate_content_stream_mock (* args , ** kwargs ):
146+ yield convert_response (result )
147+
148+ async def async_variant (* args , ** kwargs ):
149+ print ("Calling async non-streaming variant" )
150+ return sync_variant (* args , ** kwargs )
151+
152+ async def async_stream_variant (* args , ** kwargs ):
153+ print ("Calling async stream variant" )
154+ async def gen ():
155+ for result in sync_stream_variant (* args , ** kwargs ):
156+ yield result
157+ class GeneratorProvider :
158+ def __aiter__ (self ):
159+ return gen ()
160+ return GeneratorProvider ()
161+ Models .generate_content = sync_variant
162+ Models .generate_content_stream = sync_stream_variant
163+ AsyncModels .generate_content = async_variant
164+ AsyncModels .generate_content_stream = async_stream_variant
165+
65166 def _create_client (self ):
66167 self ._lazy_init ()
67168 if self ._uses_vertex :
@@ -77,5 +178,16 @@ def _create_client(self):
77178 def tearDown (self ):
78179 if self ._instrumentation_context is not None :
79180 self ._instrumentation_context .uninstall ()
181+ if self ._generate_content_mock is None :
182+ assert Models .generate_content == self ._original_generate_content
183+ assert Models .generate_content_stream == self ._original_generate_content_stream
184+ assert AsyncModels .generate_content == self ._original_async_generate_content
185+ assert AsyncModels .generate_content_stream == self ._original_async_generate_content_stream
80186 self ._requests .uninstall ()
81187 self ._otel .uninstall ()
188+ Models .generate_content = self ._original_generate_content
189+ Models .generate_content_stream = self ._original_generate_content_stream
190+ AsyncModels .generate_content = self ._original_async_generate_content
191+ AsyncModels .generate_content_stream = (
192+ self ._original_async_generate_content_stream
193+ )
0 commit comments