Skip to content

Commit 1e9984e

Browse files
committed
Discuss: handle case where all messages are filtered
1 parent 30a1ebf commit 1e9984e

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

google/generativeai/discuss.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,8 +390,11 @@ def __init__(self, **kwargs):
390390

391391
@property
392392
@set_doc(discuss_types.ChatResponse.last.__doc__)
393-
def last(self) -> str:
394-
return self.messages[-1]["content"]
393+
def last(self) -> Optional[str]:
394+
if self.messages[-1]:
395+
return self.messages[-1]["content"]
396+
else:
397+
return None
395398

396399
@last.setter
397400
def last(self, message: discuss_types.MessageOptions):
@@ -406,9 +409,14 @@ def reply(
406409
raise TypeError(
407410
f"reply can't be called on an async client, use reply_async instead."
408411
)
412+
if self.last is None:
413+
raise ValueError('The last response from the model did not return any candidates.\n'
414+
'Check the `.filters` attribute to see why the responses were filtered:\n'
415+
f'{self.filters}')
416+
409417
request = self.to_dict()
410418
request.pop("candidates")
411-
request.pop("filters")
419+
request.pop("filters", None)
412420
request["messages"] = list(request["messages"])
413421
request["messages"].append(_make_message(message))
414422
request = _make_generate_message_request(**request)
@@ -424,6 +432,7 @@ async def reply_async(
424432
)
425433
request = self.to_dict()
426434
request.pop("candidates")
435+
request.pop("filters")
427436
request["messages"] = list(request["messages"])
428437
request["messages"].append(_make_message(message))
429438
request = _make_generate_message_request(**request)
@@ -444,7 +453,11 @@ def _build_chat_response(
444453
response = type(response).to_dict(response)
445454
response.pop("messages")
446455

447-
request["messages"].append(response["candidates"][0])
456+
if response["candidates"]:
457+
last = response["candidates"][0]
458+
else:
459+
last = None
460+
request["messages"].append(last)
448461
request.setdefault("temperature", None)
449462
request.setdefault("candidate_count", None)
450463

tests/test_discuss.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,5 +268,6 @@ def test_reply(self, kwargs):
268268

269269
response = response.reply("again")
270270

271+
271272
if __name__ == "__main__":
272273
absltest.main()

0 commit comments

Comments
 (0)