Skip to content

Commit 7a7afb2

Browse files
committed
Test filters in chat
1 parent c344143 commit 7a7afb2

File tree

4 files changed

+68
-30
lines changed

4 files changed

+68
-30
lines changed

google/generativeai/discuss.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,9 @@ async def reply_async(
438438
request = _make_generate_message_request(**request)
439439
return await _generate_response_async(request=request, client=self._client)
440440

441+
def _convert_filters_to_enums(filters):
442+
for f in filters:
443+
f['reason'] = safety_types.BlockedReason(f['reason'])
441444

442445
def _build_chat_response(
443446
request: glm.GenerateMessageRequest,
@@ -453,6 +456,8 @@ def _build_chat_response(
453456
response = type(response).to_dict(response)
454457
response.pop("messages")
455458

459+
_convert_filters_to_enums(response['filters'])
460+
456461
if response["candidates"]:
457462
last = response["candidates"][0]
458463
else:

google/generativeai/types/safety_types.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"HarmCategory",
2323
"HarmProbability",
2424
"HarmBlockThreshold",
25-
"BlockReason",
25+
"BlockedReason",
2626
"ContentFilter",
2727
"SafetyRatingDict",
2828
"SafetySetting",
@@ -33,11 +33,11 @@
3333
HarmCategory = glm.HarmCategory
3434
HarmProbability = glm.SafetyRating.HarmProbability
3535
HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold
36-
BlockReason = glm.ContentFilter.BlockedReason
36+
BlockedReason = glm.ContentFilter.BlockedReason
3737

3838

3939
class ContentFilter(TypedDict):
40-
reason: BlockReason
40+
reason: BlockedReason
4141
message: str
4242

4343
__doc__ = docstring_utils.strip_oneof(glm.ContentFilter.__doc__)

tests/test_discuss.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import copy
16+
1517
import unittest.mock
1618

1719
import google.ai.generativelanguage as glm
@@ -33,19 +35,22 @@ def setUp(self):
3335

3436
self.observed_request = None
3537

36-
def fake_generate_message(
37-
request: glm.GenerateMessageRequest,
38-
) -> glm.GenerateMessageResponse:
39-
self.observed_request = request
40-
return glm.GenerateMessageResponse(
41-
messages=request.prompt.messages,
38+
self.mock_response = glm.GenerateMessageResponse(
4239
candidates=[
4340
glm.Message(content="a", author="1"),
4441
glm.Message(content="b", author="1"),
4542
glm.Message(content="c", author="1"),
4643
],
4744
)
4845

46+
def fake_generate_message(
47+
request: glm.GenerateMessageRequest,
48+
) -> glm.GenerateMessageResponse:
49+
self.observed_request = request
50+
response = copy.copy(self.mock_response)
51+
response.messages = request.prompt.messages
52+
return response
53+
4954
self.client.generate_message = fake_generate_message
5055

5156
@parameterized.named_parameters(
@@ -268,6 +273,34 @@ def test_reply(self, kwargs):
268273

269274
response = response.reply("again")
270275

276+
def test_receive_and_reply_with_filters(self):
277+
278+
self.mock_response = mock_response = glm.GenerateMessageResponse(
279+
candidates=[glm.Message(content="a", author="1")],
280+
filters=[
281+
glm.ContentFilter(reason=glm.ContentFilter.BlockedReason.SAFETY, message='unsafe'),
282+
glm.ContentFilter(reason=glm.ContentFilter.BlockedReason.OTHER),]
283+
)
284+
response = discuss.chat(messages="do filters work?")
285+
286+
filters = response.filters
287+
self.assertLen(filters, 2)
288+
self.assertIsInstance(filters[0]['reason'], glm.ContentFilter.BlockedReason)
289+
self.assertEquals(filters[0]['reason'], glm.ContentFilter.BlockedReason.SAFETY)
290+
self.assertEquals(filters[0]['message'], 'unsafe')
291+
292+
self.mock_response = glm.GenerateMessageResponse(
293+
candidates=[glm.Message(content="a", author="1")],
294+
filters=[
295+
glm.ContentFilter(reason=glm.ContentFilter.BlockedReason.BLOCKED_REASON_UNSPECIFIED)]
296+
)
297+
298+
response = response.reply('Does reply work?')
299+
filters = response.filters
300+
self.assertLen(filters, 1)
301+
self.assertIsInstance(filters[0]['reason'], glm.ContentFilter.BlockedReason)
302+
self.assertEquals(filters[0]['reason'], glm.ContentFilter.BlockedReason.BLOCKED_REASON_UNSPECIFIED)
303+
271304

272305
if __name__ == "__main__":
273306
absltest.main()

tests/test_discuss_async.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -64,28 +64,28 @@ async def fake_generate_message(
6464
client=client,
6565
)
6666

67-
self.assertEqual(
68-
observed_request,
69-
glm.GenerateMessageRequest(
70-
model="models/bard",
71-
prompt=glm.MessagePrompt(
72-
context="Example Prompt",
73-
examples=[
74-
glm.Example(
75-
input=glm.Message(content="Example from human"),
76-
output=glm.Message(content="Example response from AI"),
77-
)
78-
],
79-
messages=[glm.Message(author="0", content="Tell me a joke")],
67+
self.assertEqual(
68+
observed_request,
69+
glm.GenerateMessageRequest(
70+
model="models/bard",
71+
prompt=glm.MessagePrompt(
72+
context="Example Prompt",
73+
examples=[
74+
glm.Example(
75+
input=glm.Message(content="Example from human"),
76+
output=glm.Message(content="Example response from AI"),
77+
)
78+
],
79+
messages=[glm.Message(author="0", content="Tell me a joke")],
80+
),
81+
temperature=0.75,
82+
candidate_count=1,
8083
),
81-
temperature=0.75,
82-
candidate_count=1,
83-
),
84-
)
85-
self.assertEqual(
86-
observed_response.candidates,
87-
[{"author": "1", "content": "Why did the chicken cross the road?"}],
88-
)
84+
)
85+
self.assertEqual(
86+
observed_response.candidates,
87+
[{"author": "1", "content": "Why did the chicken cross the road?"}],
88+
)
8989

9090

9191
if __name__ == "__main__":

0 commit comments

Comments
 (0)