Skip to content

Commit f6976dd

Browse files
committed
Test candidates['safety_ratings'].
1 parent e56fe19 commit f6976dd

File tree

5 files changed

+82
-23
lines changed

5 files changed

+82
-23
lines changed

google/generativeai/client.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def configure(
4545
# but that seems rare. Users that need it can just switch to the low level API.
4646
transport: Union[str, None] = None,
4747
client_options: Union[client_options_lib.ClientOptions, dict, None] = None,
48-
client_info: Optional[gapic_v1.client_info.ClientInfo] = None
48+
client_info: Optional[gapic_v1.client_info.ClientInfo] = None,
4949
):
5050
"""Captures default client configuration.
5151
@@ -86,13 +86,13 @@ def configure(
8686

8787
user_agent = f"{USER_AGENT}/{version.__version__}"
8888
if client_info:
89-
# Be respectful of any existing agent setting.
90-
if client_info.user_agent:
91-
client_info.user_agent += f" {user_agent}"
92-
else:
93-
client_info.user_agent = user_agent
89+
# Be respectful of any existing agent setting.
90+
if client_info.user_agent:
91+
client_info.user_agent += f" {user_agent}"
92+
else:
93+
client_info.user_agent = user_agent
9494
else:
95-
client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent)
95+
client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent)
9696

9797
new_default_client_config = {
9898
"credentials": credentials,

google/generativeai/discuss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def _build_chat_response(
455455
response = type(response).to_dict(response)
456456
response.pop("messages")
457457

458-
safety_types.convert_filters_to_enums(response["filters"])
458+
response["filters"] = safety_types.convert_filters_to_enums(response["filters"])
459459

460460
if response["candidates"]:
461461
last = response["candidates"][0]

google/generativeai/text.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,11 @@ def _generate_response(
159159
response = client.generate_text(request)
160160
response = type(response).to_dict(response)
161161

162-
safety_types.convert_filters_to_enums(response["filters"])
163-
safety_types.convert_safety_feedback_to_enums(response["safety_feedback"])
162+
response["filters"] = safety_types.convert_filters_to_enums(response["filters"])
163+
response["safety_feedback"] = safety_types.convert_safety_feedback_to_enums(
164+
response["safety_feedback"]
165+
)
166+
response['candidates'] = safety_types.convert_candidate_enums(response['candidates'])
164167

165168
return Completion(_client=client, **response)
166169

google/generativeai/types/safety_types.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import enum
1717
from google.ai import generativelanguage as glm
1818
from google.generativeai import docstring_utils
19-
from typing import TypedDict
19+
from typing import Iterable, List, TypedDict
2020

2121
__all__ = [
2222
"HarmCategory",
@@ -43,9 +43,13 @@ class ContentFilterDict(TypedDict):
4343
__doc__ = docstring_utils.strip_oneof(glm.ContentFilter.__doc__)
4444

4545

46-
def convert_filters_to_enums(filters):
46+
def convert_filters_to_enums(filters: Iterable[dict]) -> List[ContentFilterDict]:
47+
result = []
4748
for f in filters:
49+
f = f.copy()
4850
f["reason"] = BlockedReason(f["reason"])
51+
result.append(f)
52+
return result
4953

5054

5155
class SafetyRatingDict(TypedDict):
@@ -55,9 +59,18 @@ class SafetyRatingDict(TypedDict):
5559
__doc__ = docstring_utils.strip_oneof(glm.SafetyRating.__doc__)
5660

5761

58-
def convert_rating_to_enum(setting):
59-
setting["category"] = HarmCategory(setting["category"])
60-
setting["probability"] = HarmProbability(setting["probability"])
62+
def convert_rating_to_enum(rating: dict) -> SafetyRatingDict:
63+
return {
64+
"category": HarmCategory(rating["category"]),
65+
"probability": HarmProbability(rating["probability"]),
66+
}
67+
68+
69+
def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]:
70+
result = []
71+
for r in ratings:
72+
result.append(convert_rating_to_enum(r))
73+
return result
6174

6275

6376
class SafetySettingDict(TypedDict):
@@ -67,9 +80,11 @@ class SafetySettingDict(TypedDict):
6780
__doc__ = docstring_utils.strip_oneof(glm.SafetySetting.__doc__)
6881

6982

70-
def convert_setting_to_enum(setting):
71-
setting["category"] = HarmCategory(setting["category"])
72-
setting["threshold"] = HarmBlockThreshold(setting["threshold"])
83+
def convert_setting_to_enum(setting: dict) -> SafetySettingDict:
84+
return {
85+
"category": HarmCategory(setting["category"]),
86+
"threshold": HarmBlockThreshold(setting["threshold"]),
87+
}
7388

7489

7590
class SafetyFeedbackDict(TypedDict):
@@ -79,7 +94,24 @@ class SafetyFeedbackDict(TypedDict):
7994
__doc__ = docstring_utils.strip_oneof(glm.SafetyFeedback.__doc__)
8095

8196

82-
def convert_safety_feedback_to_enums(safety_feedback):
97+
def convert_safety_feedback_to_enums(
98+
safety_feedback: Iterable[dict],
99+
) -> List[SafetyFeedbackDict]:
100+
result = []
83101
for sf in safety_feedback:
84-
convert_rating_to_enum(sf["rating"])
85-
convert_setting_to_enum(sf["setting"])
102+
result.append(
103+
{
104+
"rating": convert_rating_to_enum(sf["rating"]),
105+
"setting": convert_setting_to_enum(sf["setting"]),
106+
}
107+
)
108+
return result
109+
110+
111+
def convert_candidate_enums(candidates):
112+
result = []
113+
for candidate in candidates:
114+
candidate = candidate.copy()
115+
candidate['safety_ratings'] = convert_ratings_to_enum(candidate['safety_ratings'])
116+
result.append(candidate)
117+
return result

tests/test_text.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,9 +223,33 @@ def test_safety_feedback(self):
223223
safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
224224
)
225225

226-
#def test_candidate_safety_feedback(self):
226+
def test_candidate_safety_feedback(self):
227+
self.mock_response = glm.GenerateTextResponse(
228+
candidates=[
229+
{
230+
"output": "hello",
231+
"safety_ratings": [
232+
{
233+
"category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
234+
"probability": safety_types.HarmProbability.HIGH,
235+
},
236+
{
237+
"category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE,
238+
"probability": safety_types.HarmProbability.LOW,
239+
},
240+
],
241+
}
242+
]
243+
)
244+
245+
result = text_service.generate_text(prompt="Write a story from the ER.")
246+
self.assertIsInstance(result.candidates[0]['safety_ratings'][0]['category'], safety_types.HarmCategory)
247+
self.assertEqual(result.candidates[0]['safety_ratings'][0]['category'], safety_types.HarmCategory.HARM_CATEGORY_MEDICAL)
248+
249+
self.assertIsInstance(result.candidates[0]['safety_ratings'][0]['probability'], safety_types.HarmProbability)
250+
self.assertEqual(result.candidates[0]['safety_ratings'][0]['probability'], safety_types.HarmProbability.HIGH)
227251

228-
#def test_candidate_citations(self):
252+
# def test_candidate_citations(self):
229253

230254

231255
if __name__ == "__main__":

0 commit comments

Comments
 (0)