12
12
from google .generativeai import generative_models
13
13
from google .generativeai .types import content_types
14
14
from google .generativeai .types import generation_types
15
+ from google .generativeai .types import helper_types
16
+
15
17
16
18
import PIL .Image
17
19
@@ -37,49 +39,63 @@ def simple_response(text: str) -> glm.GenerateContentResponse:
37
39
return glm .GenerateContentResponse ({"candidates" : [{"content" : simple_part (text )}]})
38
40
39
41
42
+ class MockGenerativeServiceClient :
43
+ def __init__ (self , test ):
44
+ self .test = test
45
+ self .observed_requests = []
46
+ self .observed_kwargs = []
47
+ self .responses = collections .defaultdict (list )
48
+
49
+ def generate_content (
50
+ self ,
51
+ request : glm .GenerateContentRequest ,
52
+ ** kwargs ,
53
+ ) -> glm .GenerateContentResponse :
54
+ self .test .assertIsInstance (request , glm .GenerateContentRequest )
55
+ self .observed_requests .append (request )
56
+ self .observed_kwargs .append (kwargs )
57
+ response = self .responses ["generate_content" ].pop (0 )
58
+ return response
59
+
60
+ def stream_generate_content (
61
+ self ,
62
+ request : glm .GetModelRequest ,
63
+ ** kwargs ,
64
+ ) -> Iterable [glm .GenerateContentResponse ]:
65
+ self .observed_requests .append (request )
66
+ self .observed_kwargs .append (kwargs )
67
+ response = self .responses ["stream_generate_content" ].pop (0 )
68
+ return response
69
+
70
+ def count_tokens (
71
+ self ,
72
+ request : glm .CountTokensRequest ,
73
+ ** kwargs ,
74
+ ) -> Iterable [glm .GenerateContentResponse ]:
75
+ self .observed_requests .append (request )
76
+ self .observed_kwargs .append (kwargs )
77
+ response = self .responses ["count_tokens" ].pop (0 )
78
+ return response
79
+
80
+
40
81
class CUJTests (parameterized .TestCase ):
41
82
"""Tests are in order with the design doc."""
42
83
43
- def setUp (self ):
44
- self .client = unittest .mock .MagicMock ()
84
+ @property
85
+ def observed_requests (self ):
86
+ return self .client .observed_requests
45
87
46
- client_lib ._client_manager .clients ["generative" ] = self .client
47
-
48
- def add_client_method (f ):
49
- name = f .__name__
50
- setattr (self .client , name , f )
51
- return f
88
+ @property
89
+ def observed_kwargs (self ):
90
+ return self .client .observed_kwargs
52
91
53
- self .observed_requests = []
54
- self .responses = collections .defaultdict (list )
92
+ @property
93
+ def responses (self ):
94
+ return self .client .responses
55
95
56
- @add_client_method
57
- def generate_content (
58
- request : glm .GenerateContentRequest ,
59
- ** kwargs ,
60
- ) -> glm .GenerateContentResponse :
61
- self .assertIsInstance (request , glm .GenerateContentRequest )
62
- self .observed_requests .append (request )
63
- response = self .responses ["generate_content" ].pop (0 )
64
- return response
65
-
66
- @add_client_method
67
- def stream_generate_content (
68
- request : glm .GetModelRequest ,
69
- ** kwargs ,
70
- ) -> Iterable [glm .GenerateContentResponse ]:
71
- self .observed_requests .append (request )
72
- response = self .responses ["stream_generate_content" ].pop (0 )
73
- return response
74
-
75
- @add_client_method
76
- def count_tokens (
77
- request : glm .CountTokensRequest ,
78
- ** kwargs ,
79
- ) -> Iterable [glm .GenerateContentResponse ]:
80
- self .observed_requests .append (request )
81
- response = self .responses ["count_tokens" ].pop (0 )
82
- return response
96
+ def setUp (self ):
97
+ self .client = MockGenerativeServiceClient (self )
98
+ client_lib ._client_manager .clients ["generative" ] = self .client
83
99
84
100
def test_hello (self ):
85
101
# Generate text from text prompt
@@ -451,7 +467,7 @@ def test_copy_history(self):
451
467
chat1 = model .start_chat ()
452
468
chat1 .send_message ("hello1" )
453
469
454
- chat2 = copy .deepcopy (chat1 )
470
+ chat2 = copy .copy (chat1 )
455
471
chat2 .send_message ("hello2" )
456
472
457
473
chat1 .send_message ("hello3" )
@@ -810,7 +826,7 @@ def test_async_code_match(self, obj, aobj):
810
826
)
811
827
812
828
asource = re .sub (" *?# type: ignore" , "" , asource )
813
- self .assertEqual (source , asource )
829
+ self .assertEqual (source , asource , f"error in { obj = } " )
814
830
815
831
def test_repr_for_unary_non_streamed_response (self ):
816
832
model = generative_models .GenerativeModel (model_name = "gemini-pro" )
@@ -1208,15 +1224,30 @@ def test_repr_for_system_instruction(self):
1208
1224
self .assertIn ("system_instruction='Be excellent.'" , result )
1209
1225
1210
1226
def test_count_tokens_called_with_request_options (self ):
1211
- self .client .count_tokens = unittest .mock .MagicMock ()
1212
- request = unittest .mock .ANY
1227
+ self .responses ["count_tokens" ].append (glm .CountTokensResponse ())
1213
1228
request_options = {"timeout" : 120 }
1214
1229
1215
- self .responses ["count_tokens" ] = [glm .CountTokensResponse (total_tokens = 7 )]
1216
1230
model = generative_models .GenerativeModel ("gemini-pro-vision" )
1217
1231
model .count_tokens ([{"role" : "user" , "parts" : ["hello" ]}], request_options = request_options )
1218
1232
1219
- self .client .count_tokens .assert_called_once_with (request , ** request_options )
1233
+ self .assertEqual (request_options , self .observed_kwargs [0 ])
1234
+
1235
+ def test_chat_with_request_options (self ):
1236
+ self .responses ["generate_content" ].append (
1237
+ glm .GenerateContentResponse (
1238
+ {
1239
+ "candidates" : [{"finish_reason" : "STOP" }],
1240
+ }
1241
+ )
1242
+ )
1243
+ request_options = {"timeout" : 120 }
1244
+
1245
+ model = generative_models .GenerativeModel ("gemini-pro" )
1246
+ chat = model .start_chat ()
1247
+ chat .send_message ("hello" , request_options = helper_types .RequestOptions (** request_options ))
1248
+
1249
+ request_options ["retry" ] = None
1250
+ self .assertEqual (request_options , self .observed_kwargs [0 ])
1220
1251
1221
1252
1222
1253
if __name__ == "__main__" :
0 commit comments