Skip to content

Commit e81bfc1

Browse files
committed
Test text candidate citations and safety ratings.
1 parent f6976dd commit e81bfc1

File tree

3 files changed

+47
-9
lines changed

3 files changed

+47
-9
lines changed

google/generativeai/text.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,9 @@ def _generate_response(
163163
response["safety_feedback"] = safety_types.convert_safety_feedback_to_enums(
164164
response["safety_feedback"]
165165
)
166-
response['candidates'] = safety_types.convert_candidate_enums(response['candidates'])
166+
response["candidates"] = safety_types.convert_candidate_enums(
167+
response["candidates"]
168+
)
167169

168170
return Completion(_client=client, **response)
169171

google/generativeai/types/safety_types.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ def convert_candidate_enums(candidates):
112112
result = []
113113
for candidate in candidates:
114114
candidate = candidate.copy()
115-
candidate['safety_ratings'] = convert_ratings_to_enum(candidate['safety_ratings'])
115+
candidate["safety_ratings"] = convert_ratings_to_enum(
116+
candidate["safety_ratings"]
117+
)
116118
result.append(candidate)
117-
return result
119+
return result

tests/test_text.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,6 @@ def test_safety_settings(self):
167167
],
168168
)
169169

170-
# Just make sure it made it into the request object.
171170
self.assertEqual(
172171
self.observed_request.safety_settings[0].category,
173172
safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
@@ -243,13 +242,48 @@ def test_candidate_safety_feedback(self):
243242
)
244243

245244
result = text_service.generate_text(prompt="Write a story from the ER.")
246-
self.assertIsInstance(result.candidates[0]['safety_ratings'][0]['category'], safety_types.HarmCategory)
247-
self.assertEqual(result.candidates[0]['safety_ratings'][0]['category'], safety_types.HarmCategory.HARM_CATEGORY_MEDICAL)
245+
self.assertIsInstance(
246+
result.candidates[0]["safety_ratings"][0]["category"],
247+
safety_types.HarmCategory,
248+
)
249+
self.assertEqual(
250+
result.candidates[0]["safety_ratings"][0]["category"],
251+
safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
252+
)
248253

249-
self.assertIsInstance(result.candidates[0]['safety_ratings'][0]['probability'], safety_types.HarmProbability)
250-
self.assertEqual(result.candidates[0]['safety_ratings'][0]['probability'], safety_types.HarmProbability.HIGH)
254+
self.assertIsInstance(
255+
result.candidates[0]["safety_ratings"][0]["probability"],
256+
safety_types.HarmProbability,
257+
)
258+
self.assertEqual(
259+
result.candidates[0]["safety_ratings"][0]["probability"],
260+
safety_types.HarmProbability.HIGH,
261+
)
251262

252-
# def test_candidate_citations(self):
263+
def test_candidate_citations(self):
264+
self.mock_response = glm.GenerateTextResponse(
265+
candidates=[
266+
{
267+
"output": "Hello Google!",
268+
"citation_metadata": {
269+
"citation_sources": [
270+
{
271+
"start_index": 6,
272+
"end_index": 12,
273+
"uri": "https://google.com",
274+
}
275+
]
276+
},
277+
}
278+
]
279+
)
280+
result = text_service.generate_text(prompt="Hi my name is Google")
281+
self.assertEqual(
282+
result.candidates[0]["citation_metadata"]["citation_sources"][0][
283+
"start_index"
284+
],
285+
6,
286+
)
253287

254288

255289
if __name__ == "__main__":

0 commit comments

Comments
 (0)