
Commit 6b281da

Finish Semantic Retriever. (#230)

Squashed commit history:

* Make naming corpus and document also optional.
* Update logic to naming conditions
* Reformat
* Update to add document_name variable
* Update to add corpus_name variable
* Updated CustomMetadata
* Updated CustomMetadata notation
* Updating time fields on chunk object
* Testing time fields on chunk
* Try making times optional params
* None option for time
* Updated batch create chunks, updated time format for list corpora
* Reformat
* Can accept full name or partial name in get and delete methods
* Updated functionality of update methods
* Added tests for additional changes
* Add custom metadata handling for document
* These should be ints.
* Add semantic retriever config as parameter, add generate_answer_async
* Simplify semantic_retriever_config
* Fix typo
* Use .get for non-required fields for semantic retriever
* Changed the accepted source parameters, and updated custom metadata
* Testing corpus name change
* Fixed type for creating corpus, doc, and chunk
* Testing name changes
* Testing c_data
* Remove print statements
* Accept no name for chunk
* Loop for name is None for chunk
* Loop for name is None for chunk
* metadata_filters.to_proto() method
* Union for SourceNameType
* Key for source variable
* Key for source variable
* Remove isinstance for TypedDict
* Setting query
* Convert string query to content
* Query should not be list of content
* Change semantic_retriever_config parameter name to semantic_retriever
* Testing values of kwargs in _to_proto
* Check value of kwargs
* Remove print statements
* Update google/generativeai/answer.py (Co-authored-by: Mark Daoust <[email protected]>)
* Fixed custom metadata for chunk
* Change variable name c to chunk_to_update
* Implemented _to_dict() for custom_metadata
* Implemented _to_dict() for custom_metadata
* Use _to_proto instead of _to_dict
* Accessing DESCRIPTOR
* Handle custom metadata in update chunks
* Change to self.value for string_list_value
* Use _to_dict for updates
* Use _to_dict for updates
* Use _to_dict for updates
* Use _to_dict for updates
* Use _to_dict for updates
* Use dataclasses.asdict for custom metadata
* Use dataclasses.asdict for custom metadata
* Try custom method for processing custom metadata
* Try print batch update request
* Try returning document itself
* Update google/generativeai/answer.py (Co-authored-by: Mark Daoust <[email protected]>)
* Updated error message for answer parameters
* Match batch_update_chunks async method
* Fix tests
* Fix tests
* Fix pytype
* Fix pytype
* Format

Co-authored-by: Mark Daoust <[email protected]>

1 parent a187374, commit 6b281da

File tree

6 files changed: +435 -163 lines changed
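Taken together, these changes let `generate_answer` ground its response either in passages passed inline with the request or in an existing Semantic Retriever corpus, and they add an async variant. A minimal usage sketch of the new surface (the API key setup and the corpus name `my-corpus` are assumptions for illustration, not part of this commit):

    import google.generativeai as genai
    from google.generativeai import answer, retriever

    genai.configure(api_key="...")  # assumed: an API key is already provisioned

    # Ground the answer in passages sent inline with the request.
    response = answer.generate_answer(
        contents="Why is the sky blue?",
        inline_passages=["Rayleigh scattering affects short wavelengths the most."],
    )

    # Or ground it in an existing Semantic Retriever corpus.
    my_corpus = retriever.get_corpus("my-corpus")  # hypothetical corpus
    response = answer.generate_answer(
        contents="Why is the sky blue?",
        semantic_retriever=my_corpus,
    )

Exactly one of `inline_passages` and `semantic_retriever` must be set on each call; the diffs below enforce that.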

google/generativeai/answer.py

Lines changed: 167 additions & 12 deletions
@@ -17,17 +17,22 @@
 import dataclasses
 from collections.abc import Iterable
 import itertools
-from typing import Iterable, Union, Mapping, Optional, Any
+from typing import Any, Iterable, Union, Mapping, Optional, TypedDict
 
 import google.ai.generativelanguage as glm
 
-from google.generativeai.client import get_default_generative_client
+from google.generativeai.client import (
+    get_default_generative_client,
+    get_default_generative_async_client,
+)
 from google.generativeai import string_utils
 from google.generativeai.types import model_types
 from google.generativeai import models
 from google.generativeai.types import safety_types
 from google.generativeai.types import content_types
 from google.generativeai.types import answer_types
+from google.generativeai.types import retriever_types
+from google.generativeai.types.retriever_types import MetadataFilter
 
 DEFAULT_ANSWER_MODEL = "models/aqa"
 
@@ -107,11 +112,70 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingPassages:
     return glm.GroundingPassages(passages=passages)
 
 
+SourceNameType = Union[
+    str, retriever_types.Corpus, glm.Corpus, retriever_types.Document, glm.Document
+]
+
+
+class SemanticRetrieverConfigDict(TypedDict):
+    source: SourceNameType
+    query: content_types.ContentsType
+    metadata_filter: Optional[Iterable[MetadataFilter]]
+    max_chunks_count: Optional[int]
+    minimum_relevance_score: Optional[float]
+
+
+SemanticRetrieverConfigOptions = Union[
+    SourceNameType,
+    SemanticRetrieverConfigDict,
+    glm.SemanticRetrieverConfig,
+]
+
+
+def _maybe_get_source_name(source) -> str | None:
+    if isinstance(source, str):
+        return source
+    elif isinstance(
+        source, (retriever_types.Corpus, glm.Corpus, retriever_types.Document, glm.Document)
+    ):
+        return source.name
+    else:
+        return None
+
+
+def _make_semantic_retriever_config(
+    source: SemanticRetrieverConfigOptions,
+    query: content_types.ContentsType,
+) -> glm.SemanticRetrieverConfig:
+    if isinstance(source, glm.SemanticRetrieverConfig):
+        return source
+
+    name = _maybe_get_source_name(source)
+    if name is not None:
+        source = {"source": name}
+    elif isinstance(source, dict):
+        source["source"] = _maybe_get_source_name(source["source"])
+    else:
+        raise TypeError(
+            "Could not create a `glm.SemanticRetrieverConfig` from:\n"
+            f"  type: {type(source)}\n"
+            f"  value: {source}"
+        )
+
+    if source.get("query") is None:
+        source["query"] = query
+    elif isinstance(source["query"], str):
+        source["query"] = content_types.to_content(source["query"])
+
+    return glm.SemanticRetrieverConfig(source)
+
+
 def _make_generate_answer_request(
     *,
     model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
     contents: content_types.ContentsType,
-    grounding_source: GroundingPassagesOptions,
+    inline_passages: GroundingPassagesOptions | None = None,
+    semantic_retriever: SemanticRetrieverConfigOptions | None = None,
     answer_style: AnswerStyle | None = None,
     safety_settings: safety_types.SafetySettingOptions | None = None,
     temperature: float | None = None,
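To illustrate the helper added above, here is a sketch of how `_make_semantic_retriever_config` normalizes the source forms it accepts (the corpus name is hypothetical, and the helper is private to the module):

    import google.ai.generativelanguage as glm
    from google.generativeai import answer

    # A bare resource name (or a Corpus/Document object) is wrapped into a
    # config; the query falls back to the one passed alongside it.
    config = answer._make_semantic_retriever_config(
        "corpora/my-corpus",  # hypothetical corpus name
        query=glm.Content(parts=[glm.Part(text="Why is the sky blue?")]),
    )

    # A SemanticRetrieverConfigDict may also carry retrieval limits; a string
    # query inside the dict is converted to a glm.Content.
    config = answer._make_semantic_retriever_config(
        {"source": "corpora/my-corpus", "query": "Why is the sky blue?", "max_chunks_count": 5},
        query=None,
    )
    assert isinstance(config, glm.SemanticRetrieverConfig)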
@@ -124,7 +188,11 @@ def _make_generate_answer_request(
         contents: Content of the current conversation with the model. For single-turn query, this is a
             single question to answer. For multi-turn queries, this is a repeated field that contains
             conversation history and the last `Content` in the list containing the question.
-        grounding_source: Sources in which to grounding the answer.
+        inline_passages: Grounding passages (a list of `Content`-like objects or `(id, content)` pairs,
+            or a `glm.GroundingPassages`) to send inline with the request. Exclusive with
+            `semantic_retriever`; one must be set, but not both.
+        semantic_retriever: A Corpus, Document, or `glm.SemanticRetrieverConfig` to use for grounding.
+            Exclusive with `inline_passages`; one must be set, but not both.
         answer_style: Style for grounded answers.
         safety_settings: Safety settings for generated output.
         temperature: The temperature for randomness in the output.
@@ -141,15 +209,27 @@ def _make_generate_answer_request(
         safety_settings, harm_category_set="new"
     )
 
-    grounding_source = _make_grounding_passages(grounding_source)
+    if inline_passages is not None and semantic_retriever is not None:
+        raise ValueError(
+            "Either `inline_passages` or `semantic_retriever` must be set, not both."
+        )
+    elif inline_passages is not None:
+        inline_passages = _make_grounding_passages(inline_passages)
+    elif semantic_retriever is not None:
+        semantic_retriever = _make_semantic_retriever_config(semantic_retriever, contents[-1])
+    else:
+        raise TypeError(
+            "One of `inline_passages` or `semantic_retriever` must be set, but both are `None`."
+        )
 
     if answer_style:
         answer_style = to_answer_style(answer_style)
 
     return glm.GenerateAnswerRequest(
         model=model,
         contents=contents,
-        inline_passages=grounding_source,
+        inline_passages=inline_passages,
+        semantic_retriever=semantic_retriever,
         safety_settings=safety_settings,
         temperature=temperature,
         answer_style=answer_style,
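A sketch of the validation above: the two grounding sources are mutually exclusive, and at least one is required (the question text and corpus name are placeholders):

    from google.generativeai import answer

    # Passing both grounding sources is rejected...
    try:
        answer._make_generate_answer_request(
            contents=["Why is the sky blue?"],
            inline_passages=["Rayleigh scattering."],
            semantic_retriever="corpora/my-corpus",
        )
    except ValueError as e:
        print(e)  # Either `inline_passages` or `semantic_retriever` must be set, not both.

    # ...and so is passing neither.
    try:
        answer._make_generate_answer_request(contents=["Why is the sky blue?"])
    except TypeError as e:
        print(e)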
@@ -160,23 +240,48 @@ def generate_answer(
     *,
     model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
     contents: content_types.ContentsType,
-    inline_passages: GroundingPassagesOptions,
+    inline_passages: GroundingPassagesOptions | None = None,
+    semantic_retriever: SemanticRetrieverConfigOptions | None = None,
     answer_style: AnswerStyle | None = None,
     safety_settings: safety_types.SafetySettingOptions | None = None,
     temperature: float | None = None,
     client: glm.GenerativeServiceClient | None = None,
     request_options: dict[str, Any] | None = None,
 ):
-    """
-    Calls the API and returns a `types.Answer` containing the answer.
+    """
+    Calls the GenerateAnswer API and returns a `types.Answer` containing the response.
+
+    You can pass a literal list of text chunks:
+
+    >>> from google.generativeai import answer
+    >>> answer.generate_answer(
+    ...     contents=question,
+    ...     inline_passages=splitter.split(document)
+    ... )
+
+    Or pass a reference to a retriever Document or Corpus:
+
+    >>> from google.generativeai import answer
+    >>> from google.generativeai import retriever
+    >>> my_corpus = retriever.get_corpus('my_corpus')
+    >>> answer.generate_answer(
+    ...     contents=question,
+    ...     semantic_retriever=my_corpus
+    ... )
+
 
     Args:
         model: Which model to call, as a string or a `types.Model`.
-        question: The question to be answered by the model, grounded in the
+        contents: The question to be answered by the model, grounded in the
             provided source.
-        grounding_source: Source indicating the passages in which the answer should be grounded.
+        inline_passages: Grounding passages (a list of `Content`-like objects or `(id, content)` pairs,
+            or a `glm.GroundingPassages`) to send inline with the request. Exclusive with
+            `semantic_retriever`; one must be set, but not both.
+        semantic_retriever: A Corpus, Document, or `glm.SemanticRetrieverConfig` to use for grounding.
+            Exclusive with `inline_passages`; one must be set, but not both.
         answer_style: Style in which the grounded answer should be returned.
         safety_settings: Safety settings for generated output. Defaults to None.
+        temperature: Controls the randomness of the output.
         client: If you're not relying on a default client, you pass a `glm.GenerativeServiceClient` instead.
         request_options: Options for the request.
@@ -192,7 +297,8 @@ def generate_answer(
     request = _make_generate_answer_request(
         model=model,
         contents=contents,
-        grounding_source=inline_passages,
+        inline_passages=inline_passages,
+        semantic_retriever=semantic_retriever,
         safety_settings=safety_settings,
         temperature=temperature,
         answer_style=answer_style,
@@ -201,3 +307,52 @@ def generate_answer(
     response = client.generate_answer(request, **request_options)
 
     return response
+
+
+async def generate_answer_async(
+    *,
+    model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
+    contents: content_types.ContentsType,
+    inline_passages: GroundingPassagesOptions | None = None,
+    semantic_retriever: SemanticRetrieverConfigOptions | None = None,
+    answer_style: AnswerStyle | None = None,
+    safety_settings: safety_types.SafetySettingOptions | None = None,
+    temperature: float | None = None,
+    client: glm.GenerativeServiceAsyncClient | None = None,
+):
+    """
+    Calls the API and returns a `types.Answer` containing the answer.
+
+    Args:
+        model: Which model to call, as a string or a `types.Model`.
+        contents: The question to be answered by the model, grounded in the
+            provided source.
+        inline_passages: Grounding passages (a list of `Content`-like objects or `(id, content)` pairs,
+            or a `glm.GroundingPassages`) to send inline with the request. Exclusive with
+            `semantic_retriever`; one must be set, but not both.
+        semantic_retriever: A Corpus, Document, or `glm.SemanticRetrieverConfig` to use for grounding.
+            Exclusive with `inline_passages`; one must be set, but not both.
+        answer_style: Style in which the grounded answer should be returned.
+        safety_settings: Safety settings for generated output. Defaults to None.
+        temperature: Controls the randomness of the output.
+        client: If you're not relying on a default client, you pass a `glm.GenerativeServiceAsyncClient` instead.
+
+    Returns:
+        A `types.Answer` containing the model's text answer response.
+    """
+    if client is None:
+        client = get_default_generative_async_client()
+
+    request = _make_generate_answer_request(
+        model=model,
+        contents=contents,
+        inline_passages=inline_passages,
+        semantic_retriever=semantic_retriever,
+        safety_settings=safety_settings,
+        temperature=temperature,
+        answer_style=answer_style,
+    )
+
+    response = await client.generate_answer(request)
+
+    return response
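A usage sketch for the new async entry point (assumes a configured API key and an existing corpus; the corpus name is hypothetical):

    import asyncio

    from google.generativeai import answer

    async def main():
        response = await answer.generate_answer_async(
            contents="Why is the sky blue?",
            semantic_retriever="corpora/my-corpus",  # hypothetical corpus name
        )
        print(response)

    asyncio.run(main())

Note that, unlike the sync version, this variant does not accept `request_options`.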

google/generativeai/retriever.py

Lines changed: 8 additions & 12 deletions
@@ -28,8 +28,8 @@
 
 
 def create_corpus(
-    name: Optional[str] = None,
-    display_name: Optional[str] = None,
+    name: str | None = None,
+    display_name: str | None = None,
     client: glm.RetrieverServiceClient | None = None,
     request_options: dict[str, Any] | None = None,
 ) -> retriever_types.Corpus:
@@ -58,12 +58,10 @@ def create_corpus(
     if client is None:
         client = get_default_retriever_client()
 
-    corpus, corpus_name = None, None
     if name is None:
-        corpus = glm.Corpus(name=corpus_name, display_name=display_name)
+        corpus = glm.Corpus(display_name=display_name)
     elif retriever_types.valid_name(name):
-        corpus_name = "corpora/" + name  # Construct the name
-        corpus = glm.Corpus(name=corpus_name, display_name=display_name)
+        corpus = glm.Corpus(name=f"corpora/{name}", display_name=display_name)
     else:
         raise ValueError(retriever_types.NAME_ERROR_MSG.format(length=len(name), name=name))
 
@@ -77,8 +75,8 @@ def create_corpus(
 
 
 async def create_corpus_async(
-    name: Optional[str] = None,
-    display_name: Optional[str] = None,
+    name: str | None = None,
+    display_name: str | None = None,
     client: glm.RetrieverServiceAsyncClient | None = None,
     request_options: dict[str, Any] | None = None,
 ) -> retriever_types.Corpus:
@@ -89,12 +87,10 @@ async def create_corpus_async(
     if client is None:
         client = get_default_retriever_async_client()
 
-    corpus, corpus_name = None, None
     if name is None:
-        corpus = glm.Corpus(name=corpus_name, display_name=display_name)
+        corpus = glm.Corpus(display_name=display_name)
     elif retriever_types.valid_name(name):
-        corpus_name = "corpora/" + name  # Construct the name
-        corpus = glm.Corpus(name=corpus_name, display_name=display_name)
+        corpus = glm.Corpus(name=f"corpora/{name}", display_name=display_name)
     else:
         raise ValueError(retriever_types.NAME_ERROR_MSG.format(length=len(name), name=name))
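The simplification above keeps name handling in one place: the `corpora/` prefix is attached only when a valid short name is supplied, and with no name the service assigns one. A sketch (display names are placeholders):

    from google.generativeai import retriever

    # With an explicit short name, the resource becomes "corpora/my-corpus".
    corpus = retriever.create_corpus(name="my-corpus", display_name="My corpus")

    # With no name, the service generates the resource name.
    corpus = retriever.create_corpus(display_name="Auto-named corpus")

    # An invalid name raises ValueError with retriever_types.NAME_ERROR_MSG.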

google/generativeai/types/generation_types.py

Lines changed: 0 additions & 1 deletion
@@ -101,7 +101,6 @@ class GenerationConfig:
             in the model's specification.
         temperature:
             Controls the randomness of the output. Note: The
-
            default value varies by model, see the `Model.temperature`
             attribute of the `Model` returned by the `genai.get_model`
             function.
