|
30 | 30 | _DEFAULT_API_ENDPOINT = "generativelanguage.googleapis.com" |
31 | 31 | _USER_AGENT = f"langchain/{langchain_core.__version__}" |
32 | 32 | _DEFAULT_PAGE_SIZE = 20 |
33 | | -_DEFAULT_GENERATE_SERVICE_MODEL = "models/aqa" |
34 | 33 | _MAX_REQUEST_PER_CHUNK = 100 |
35 | 34 | _NAME_REGEX = re.compile(r"^corpora/([^/]+?)(/documents/([^/]+?)(/chunks/([^/]+?))?)?$") |
36 | 35 |
|
@@ -230,6 +229,7 @@ def _get_credentials() -> credentials.Credentials | None: |
230 | 229 |
|
231 | 230 |
|
232 | 231 | def build_semantic_retriever() -> genai.RetrieverServiceClient: |
| 232 | + """Uses the default `'grpc'` transport to build a semantic retriever client.""" |
233 | 233 | credentials = _get_credentials() |
234 | 234 | return genai.RetrieverServiceClient( |
235 | 235 | credentials=credentials, |
@@ -598,128 +598,6 @@ def query_document( |
598 | 598 | return list(response.relevant_chunks) |
599 | 599 |
|
600 | 600 |
|
601 | | -@dataclass |
602 | | -class Passage: |
603 | | - text: str |
604 | | - id: str |
605 | | - |
606 | | - |
607 | | -@dataclass |
608 | | -class GroundedAnswer: |
609 | | - answer: str |
610 | | - attributed_passages: list[Passage] |
611 | | - answerable_probability: float | None |
612 | | - |
613 | | - |
614 | | -@dataclass |
615 | | -class GenerateAnswerError(Exception): |
616 | | - finish_reason: genai.Candidate.FinishReason |
617 | | - finish_message: str |
618 | | - safety_ratings: MutableSequence[genai.SafetyRating] |
619 | | - |
620 | | - def __str__(self) -> str: |
621 | | - return ( |
622 | | - f"finish_reason: {self.finish_reason} " |
623 | | - f"finish_message: {self.finish_message} " |
624 | | - f"safety ratings: {self.safety_ratings}" |
625 | | - ) |
626 | | - |
627 | | - |
628 | | -def generate_answer( |
629 | | - *, |
630 | | - prompt: str, |
631 | | - passages: list[str], |
632 | | - answer_style: int = genai.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE, |
633 | | - safety_settings: list[genai.SafetySetting] | None = None, |
634 | | - temperature: float | None = None, |
635 | | - client: genai.GenerativeServiceClient, |
636 | | -) -> GroundedAnswer: |
637 | | - # TODO: Consider passing in the corpus ID instead of the actual |
638 | | - # passages. |
639 | | - if safety_settings is None: |
640 | | - safety_settings = [] |
641 | | - response = client.generate_answer( |
642 | | - genai.GenerateAnswerRequest( |
643 | | - contents=[ |
644 | | - genai.Content(parts=[genai.Part(text=prompt)]), |
645 | | - ], |
646 | | - model=_DEFAULT_GENERATE_SERVICE_MODEL, |
647 | | - answer_style=answer_style, |
648 | | - safety_settings=safety_settings, |
649 | | - temperature=temperature, |
650 | | - inline_passages=genai.GroundingPassages( |
651 | | - passages=[ |
652 | | - genai.GroundingPassage( |
653 | | - # IDs here takes alphanumeric only. No dashes allowed. |
654 | | - id=str(index), |
655 | | - content=genai.Content(parts=[genai.Part(text=chunk)]), |
656 | | - ) |
657 | | - for index, chunk in enumerate(passages) |
658 | | - ] |
659 | | - ), |
660 | | - ) |
661 | | - ) |
662 | | - |
663 | | - if response.answer.finish_reason != genai.Candidate.FinishReason.STOP: |
664 | | - finish_message = _get_finish_message(response.answer) |
665 | | - raise GenerateAnswerError( |
666 | | - finish_reason=response.answer.finish_reason, |
667 | | - finish_message=finish_message, |
668 | | - safety_ratings=response.answer.safety_ratings, |
669 | | - ) |
670 | | - |
671 | | - assert len(response.answer.content.parts) == 1 |
672 | | - return GroundedAnswer( |
673 | | - answer=response.answer.content.parts[0].text, |
674 | | - attributed_passages=[ |
675 | | - Passage( |
676 | | - text=passage.content.parts[0].text, |
677 | | - id=passage.source_id.grounding_passage.passage_id, |
678 | | - ) |
679 | | - for passage in response.answer.grounding_attributions |
680 | | - if len(passage.content.parts) > 0 |
681 | | - ], |
682 | | - answerable_probability=response.answerable_probability, |
683 | | - ) |
684 | | - |
685 | | - |
686 | | -def _get_finish_message(candidate: genai.Candidate) -> str: |
687 | | - """Get a human-readable finish message from the candidate. |
688 | | -
|
689 | | - Uses the official finish_message field if available, otherwise falls back |
690 | | - to a manual mapping of finish reasons to descriptive messages. |
691 | | - """ |
692 | | - # Use the official field when available |
693 | | - if hasattr(candidate, "finish_message") and candidate.finish_message: |
694 | | - return candidate.finish_message |
695 | | - |
696 | | - # Fallback to manual mapping for all known finish reasons |
697 | | - finish_messages: dict[int, str] = { |
698 | | - genai.Candidate.FinishReason.STOP: "Generation completed successfully", |
699 | | - genai.Candidate.FinishReason.MAX_TOKENS: ( |
700 | | - "Maximum token in context window reached" |
701 | | - ), |
702 | | - genai.Candidate.FinishReason.SAFETY: "Blocked because of safety", |
703 | | - genai.Candidate.FinishReason.RECITATION: "Blocked because of recitation", |
704 | | - genai.Candidate.FinishReason.LANGUAGE: "Unsupported language detected", |
705 | | - genai.Candidate.FinishReason.BLOCKLIST: "Content hit forbidden terms", |
706 | | - genai.Candidate.FinishReason.PROHIBITED_CONTENT: ( |
707 | | - "Inappropriate content detected" |
708 | | - ), |
709 | | - genai.Candidate.FinishReason.SPII: "Sensitive personal information detected", |
710 | | - genai.Candidate.FinishReason.IMAGE_SAFETY: "Image safety violation", |
711 | | - genai.Candidate.FinishReason.MALFORMED_FUNCTION_CALL: "Malformed function call", |
712 | | - genai.Candidate.FinishReason.UNEXPECTED_TOOL_CALL: "Unexpected tool call", |
713 | | - genai.Candidate.FinishReason.OTHER: "Other generation issue", |
714 | | - genai.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED: ( |
715 | | - "Unspecified finish reason" |
716 | | - ), |
717 | | - } |
718 | | - |
719 | | - finish_reason = candidate.finish_reason |
720 | | - return finish_messages.get(finish_reason, "Unexpected generation error") |
721 | | - |
722 | | - |
723 | 601 | def _convert_to_metadata(metadata: dict[str, Any]) -> list[genai.CustomMetadata]: |
724 | 602 | cs: list[genai.CustomMetadata] = [] |
725 | 603 | for key, value in metadata.items(): |
|
0 commit comments