@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import os
+import copy
 import unittest
 import unittest.mock as mock
@@ -22,6 +21,7 @@
 from google.generativeai import text as text_service
 from google.generativeai import client
 from google.generativeai.types import safety_types
+from google.generativeai.types import model_types
 from absl.testing import absltest
 from absl.testing import parameterized
 
@@ -31,8 +31,9 @@ def setUp(self):
         self.client = unittest.mock.MagicMock()
 
         client._client_manager.text_client = self.client
+        client._client_manager.model_client = self.client
 
-        self.observed_request = None
+        self.observed_requests = []
 
         self.responses = {}
 
@@ -45,23 +46,37 @@ def add_client_method(f):
         def generate_text(
             request: glm.GenerateTextRequest,
         ) -> glm.GenerateTextResponse:
-            self.observed_request = request
+            self.observed_requests.append(request)
             return self.responses["generate_text"]
 
         @add_client_method
         def embed_text(
             request: glm.EmbedTextRequest,
         ) -> glm.EmbedTextResponse:
-            self.observed_request = request
+            self.observed_requests.append(request)
             return self.responses["embed_text"]
 
         @add_client_method
         def batch_embed_text(
             request: glm.EmbedTextRequest,
         ) -> glm.EmbedTextResponse:
-            self.observed_request = request
+            self.observed_requests.append(request)
             return self.responses["batch_embed_text"]
 
+        @add_client_method
+        def count_text_tokens(
+            request: glm.CountTextTokensRequest,
+        ) -> glm.CountTextTokensResponse:
+            self.observed_requests.append(request)
+            return self.responses["count_text_tokens"]
+
+        @add_client_method
+        def get_tuned_model(name) -> glm.TunedModel:
+            request = glm.GetTunedModelRequest(name=name)
+            self.observed_requests.append(request)
+            response = copy.copy(self.responses["get_tuned_model"])
+            return response
+
     @parameterized.named_parameters(
         [
             dict(testcase_name="string", prompt="Hello how are"),
@@ -99,7 +114,7 @@ def test_generate_embeddings(self, model, text):
         emb = text_service.generate_embeddings(model=model, text=text)
 
         self.assertIsInstance(emb, dict)
-        self.assertEqual(self.observed_request, glm.EmbedTextRequest(model=model, text=text))
+        self.assertEqual(self.observed_requests[-1], glm.EmbedTextRequest(model=model, text=text))
         self.assertIsInstance(emb["embedding"][0], float)
 
     @parameterized.named_parameters(
@@ -123,8 +138,7 @@ def test_generate_embeddings_batch(self, model, text):
 
         self.assertIsInstance(emb, dict)
         self.assertEqual(
-            self.observed_request,
-            glm.BatchEmbedTextRequest(model=model, texts=text),
+            self.observed_requests[-1], glm.BatchEmbedTextRequest(model=model, texts=text)
         )
         self.assertIsInstance(emb["embedding"][0], list)
 
@@ -160,7 +174,7 @@ def test_generate_response(self, *, prompt, **kwargs):
         complete = text_service.generate_text(prompt=prompt, **kwargs)
 
         self.assertEqual(
-            self.observed_request,
+            self.observed_requests[-1],
             glm.GenerateTextRequest(
                 model="models/text-bison-001", prompt=glm.TextPrompt(text=prompt), **kwargs
             ),
@@ -188,15 +202,15 @@ def test_stop_string(self):
         complete = text_service.generate_text(prompt="Hello", stop_sequences="stop")
 
         self.assertEqual(
-            self.observed_request,
+            self.observed_requests[-1],
             glm.GenerateTextRequest(
                 model="models/text-bison-001",
                 prompt=glm.TextPrompt(text="Hello"),
                 stop_sequences=["stop"],
             ),
         )
         # Just make sure it made it into the request object.
-        self.assertEqual(self.observed_request.stop_sequences, ["stop"])
+        self.assertEqual(self.observed_requests[-1].stop_sequences, ["stop"])
 
     @parameterized.named_parameters(
         [
@@ -251,7 +265,7 @@ def test_safety_settings(self, safety_settings):
         )
 
         self.assertEqual(
-            self.observed_request.safety_settings[0].category,
+            self.observed_requests[-1].safety_settings[0].category,
             safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
         )
 
@@ -367,6 +381,72 @@ def test_candidate_citations(self):
             6,
         )
 
+    @parameterized.named_parameters(
+        [
+            dict(testcase_name="base-name", model="models/text-bison-001"),
+            dict(testcase_name="tuned-name", model="tunedModels/bipedal-pangolin-001"),
+            dict(
+                testcase_name="model",
+                model=model_types.Model(
+                    name="models/text-bison-001",
+                    base_model_id="text-bison-001",
+                    version="001",
+                    display_name="🦬",
+                    description="🦬🦬🦬🦬🦬🦬🦬🦬🦬🦬🦬",
+                    input_token_limit=8000,
+                    output_token_limit=4000,
+                    supported_generation_methods=["GenerateText"],
+                ),
+            ),
+            dict(
+                testcase_name="tuned_model",
+                model=model_types.TunedModel(
+                    name="tunedModels/bipedal-pangolin-001",
+                    base_model="models/text-bison-001",
+                ),
+            ),
+            dict(
+                testcase_name="glm_model",
+                model=glm.Model(
+                    name="models/text-bison-001",
+                ),
+            ),
+            dict(
+                testcase_name="glm_tuned_model",
+                model=glm.TunedModel(
+                    name="tunedModels/bipedal-pangolin-001",
+                    base_model="models/text-bison-001",
+                ),
+            ),
+            dict(
+                testcase_name="glm_tuned_model_nested",
+                model=glm.TunedModel(
+                    name="tunedModels/bipedal-pangolin-002",
+                    tuned_model_source={
+                        "tuned_model": "tunedModels/bipedal-pangolin-002",
+                        "base_model": "models/text-bison-001",
+                    },
+                ),
+            ),
+        ]
+    )
+    def test_count_message_tokens(self, model):
+        self.responses["get_tuned_model"] = glm.TunedModel(
+            name="tunedModels/bipedal-pangolin-001", base_model="models/text-bison-001"
+        )
+        self.responses["count_text_tokens"] = glm.CountTextTokensResponse(token_count=7)
+
+        response = text_service.count_text_tokens(model, "Tell me a story about a magic backpack.")
+        self.assertEqual({"token_count": 7}, response)
+
+        should_look_up_model = isinstance(model, str) and model.startswith("tunedModels/")
+        if should_look_up_model:
+            self.assertLen(self.observed_requests, 2)
+            self.assertEqual(
+                self.observed_requests[0],
+                glm.GetTunedModelRequest(name="tunedModels/bipedal-pangolin-001"),
+            )
+
 
 if __name__ == "__main__":
     absltest.main()
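
Note: the change threading through every hunk above is the switch from a single self.observed_request slot to a self.observed_requests list, so the new tuned-model test can assert on both requests of a two-step flow (the GetTunedModelRequest lookup followed by the CountTextTokensRequest) rather than only the last one. A minimal self-contained sketch of that recording pattern, with hypothetical string requests standing in for the real request protos:

import unittest
import unittest.mock as mock


class RecordingPatternSketch(unittest.TestCase):
    # Sketch only: illustrates the observed_requests refactor, not the real test client.

    def test_records_every_request(self):
        observed_requests = []  # a list replaces the old single observed_request slot
        stub = mock.MagicMock()

        def count_text_tokens(request):
            observed_requests.append(request)  # record the incoming request ...
            return {"token_count": 7}  # ... then reply with a canned response

        stub.count_text_tokens = count_text_tokens

        stub.count_text_tokens("lookup-request")  # hypothetical first request
        stub.count_text_tokens("count-request")  # hypothetical second request

        # Every request stays inspectable, so multi-request flows can be
        # asserted end to end instead of only checking the final request.
        self.assertEqual(observed_requests, ["lookup-request", "count-request"])
        self.assertEqual(observed_requests[-1], "count-request")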
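
From the caller's side, the behavior that test_count_message_tokens pins down could look like the following. This is hypothetical usage, not part of the diff: it assumes a configured API client, and the real token count depends on the model.

from google.generativeai import text as text_service

# Passing a plain "tunedModels/..." string makes the service resolve the tuned
# model to its base model first (one GetTunedModelRequest) and then count the
# tokens (one CountTextTokensRequest): two requests total, as asserted above.
info = text_service.count_text_tokens(
    "tunedModels/bipedal-pangolin-001",
    "Tell me a story about a magic backpack.",
)
print(info)  # e.g. {'token_count': 7}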