@@ -167,7 +167,6 @@ def test_safety_settings(self):
167
167
],
168
168
)
169
169
170
- # Just make sure it made it into the request object.
171
170
self .assertEqual (
172
171
self .observed_request .safety_settings [0 ].category ,
173
172
safety_types .HarmCategory .HARM_CATEGORY_MEDICAL ,
@@ -243,13 +242,48 @@ def test_candidate_safety_feedback(self):
243
242
)
244
243
245
244
result = text_service .generate_text (prompt = "Write a story from the ER." )
246
- self .assertIsInstance (result .candidates [0 ]['safety_ratings' ][0 ]['category' ], safety_types .HarmCategory )
247
- self .assertEqual (result .candidates [0 ]['safety_ratings' ][0 ]['category' ], safety_types .HarmCategory .HARM_CATEGORY_MEDICAL )
245
+ self .assertIsInstance (
246
+ result .candidates [0 ]["safety_ratings" ][0 ]["category" ],
247
+ safety_types .HarmCategory ,
248
+ )
249
+ self .assertEqual (
250
+ result .candidates [0 ]["safety_ratings" ][0 ]["category" ],
251
+ safety_types .HarmCategory .HARM_CATEGORY_MEDICAL ,
252
+ )
248
253
249
- self .assertIsInstance (result .candidates [0 ]['safety_ratings' ][0 ]['probability' ], safety_types .HarmProbability )
250
- self .assertEqual (result .candidates [0 ]['safety_ratings' ][0 ]['probability' ], safety_types .HarmProbability .HIGH )
254
+ self .assertIsInstance (
255
+ result .candidates [0 ]["safety_ratings" ][0 ]["probability" ],
256
+ safety_types .HarmProbability ,
257
+ )
258
+ self .assertEqual (
259
+ result .candidates [0 ]["safety_ratings" ][0 ]["probability" ],
260
+ safety_types .HarmProbability .HIGH ,
261
+ )
251
262
252
- # def test_candidate_citations(self):
263
+ def test_candidate_citations (self ):
264
+ self .mock_response = glm .GenerateTextResponse (
265
+ candidates = [
266
+ {
267
+ "output" : "Hello Google!" ,
268
+ "citation_metadata" : {
269
+ "citation_sources" : [
270
+ {
271
+ "start_index" : 6 ,
272
+ "end_index" : 12 ,
273
+ "uri" : "https://google.com" ,
274
+ }
275
+ ]
276
+ },
277
+ }
278
+ ]
279
+ )
280
+ result = text_service .generate_text (prompt = "Hi my name is Google" )
281
+ self .assertEqual (
282
+ result .candidates [0 ]["citation_metadata" ]["citation_sources" ][0 ][
283
+ "start_index"
284
+ ],
285
+ 6 ,
286
+ )
253
287
254
288
255
289
if __name__ == "__main__" :
0 commit comments