
Commit b3a35c1

Support for new TaskTypes and output_dimensionality param (#285)
* Update Task Types for embeddings
* Add support for param
* format
* fix typo
* Add guard against negative output_dim
* Update docs
* async code match
* update tests
* format
* Update task types
1 parent 50f8c12 · commit b3a35c1

File tree: 3 files changed, +88 -6 lines

  google/generativeai/embedding.py
  tests/test_embedding.py
  tests/test_embedding_async.py


google/generativeai/embedding.py

Lines changed: 46 additions & 4 deletions
@@ -59,6 +59,14 @@
     EmbeddingTaskType.CLUSTERING: EmbeddingTaskType.CLUSTERING,
     5: EmbeddingTaskType.CLUSTERING,
     "clustering": EmbeddingTaskType.CLUSTERING,
+    6: EmbeddingTaskType.QUESTION_ANSWERING,
+    "question_answering": EmbeddingTaskType.QUESTION_ANSWERING,
+    "qa": EmbeddingTaskType.QUESTION_ANSWERING,
+    EmbeddingTaskType.QUESTION_ANSWERING: EmbeddingTaskType.QUESTION_ANSWERING,
+    7: EmbeddingTaskType.FACT_VERIFICATION,
+    "fact_verification": EmbeddingTaskType.FACT_VERIFICATION,
+    "verification": EmbeddingTaskType.FACT_VERIFICATION,
+    EmbeddingTaskType.FACT_VERIFICATION: EmbeddingTaskType.FACT_VERIFICATION,
 }


@@ -94,6 +102,7 @@ def embed_content(
     content: content_types.ContentType,
     task_type: EmbeddingTaskTypeOptions | None = None,
     title: str | None = None,
+    output_dimensionality: int | None = None,
     client: glm.GenerativeServiceClient | None = None,
     request_options: dict[str, Any] | None = None,
 ) -> text_types.EmbeddingDict: ...
@@ -105,6 +114,7 @@ def embed_content(
     content: Iterable[content_types.ContentType],
     task_type: EmbeddingTaskTypeOptions | None = None,
     title: str | None = None,
+    output_dimensionality: int | None = None,
     client: glm.GenerativeServiceClient | None = None,
     request_options: dict[str, Any] | None = None,
 ) -> text_types.BatchEmbeddingDict: ...
@@ -115,6 +125,7 @@ def embed_content(
     content: content_types.ContentType | Iterable[content_types.ContentType],
     task_type: EmbeddingTaskTypeOptions | None = None,
     title: str | None = None,
+    output_dimensionality: int | None = None,
     client: glm.GenerativeServiceClient = None,
     request_options: dict[str, Any] | None = None,
 ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
@@ -135,6 +146,12 @@ def embed_content(
         title:
             An optional title for the text. Only applicable when task_type is
             `RETRIEVAL_DOCUMENT`.
+
+        output_dimensionality:
+            Optional reduced dimensionality for the output embeddings. If set,
+            excessive values from the output embeddings will be truncated from
+            the end.
+
         request_options:
             Options for the request.

@@ -155,14 +172,21 @@ def embed_content(
             "If a title is specified, the task must be a retrieval document type task."
         )

+    if output_dimensionality and output_dimensionality < 0:
+        raise ValueError("`output_dimensionality` must be a non-negative integer.")
+
     if task_type:
         task_type = to_task_type(task_type)

     if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)):
         result = {"embedding": []}
         requests = (
             glm.EmbedContentRequest(
-                model=model, content=content_types.to_content(c), task_type=task_type, title=title
+                model=model,
+                content=content_types.to_content(c),
+                task_type=task_type,
+                title=title,
+                output_dimensionality=output_dimensionality,
             )
             for c in content
         )
@@ -177,7 +201,11 @@ def embed_content(
         return result
     else:
         embedding_request = glm.EmbedContentRequest(
-            model=model, content=content_types.to_content(content), task_type=task_type, title=title
+            model=model,
+            content=content_types.to_content(content),
+            task_type=task_type,
+            title=title,
+            output_dimensionality=output_dimensionality,
         )
         embedding_response = client.embed_content(
             embedding_request,
@@ -194,6 +222,7 @@ async def embed_content_async(
     content: content_types.ContentType,
     task_type: EmbeddingTaskTypeOptions | None = None,
     title: str | None = None,
+    output_dimensionality: int | None = None,
     client: glm.GenerativeServiceAsyncClient | None = None,
     request_options: dict[str, Any] | None = None,
 ) -> text_types.EmbeddingDict: ...
@@ -205,6 +234,7 @@ async def embed_content_async(
     content: Iterable[content_types.ContentType],
     task_type: EmbeddingTaskTypeOptions | None = None,
     title: str | None = None,
+    output_dimensionality: int | None = None,
     client: glm.GenerativeServiceAsyncClient | None = None,
     request_options: dict[str, Any] | None = None,
 ) -> text_types.BatchEmbeddingDict: ...
@@ -215,6 +245,7 @@ async def embed_content_async(
     content: content_types.ContentType | Iterable[content_types.ContentType],
     task_type: EmbeddingTaskTypeOptions | None = None,
     title: str | None = None,
+    output_dimensionality: int | None = None,
     client: glm.GenerativeServiceAsyncClient = None,
     request_options: dict[str, Any] | None = None,
 ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
@@ -232,14 +263,21 @@ async def embed_content_async(
             "If a title is specified, the task must be a retrieval document type task."
         )

+    if output_dimensionality and output_dimensionality < 0:
+        raise ValueError("`output_dimensionality` must be a non-negative integer.")
+
     if task_type:
         task_type = to_task_type(task_type)

     if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)):
         result = {"embedding": []}
         requests = (
             glm.EmbedContentRequest(
-                model=model, content=content_types.to_content(c), task_type=task_type, title=title
+                model=model,
+                content=content_types.to_content(c),
+                task_type=task_type,
+                title=title,
+                output_dimensionality=output_dimensionality,
             )
             for c in content
         )
@@ -254,7 +292,11 @@ async def embed_content_async(
         return result
     else:
         embedding_request = glm.EmbedContentRequest(
-            model=model, content=content_types.to_content(content), task_type=task_type, title=title
+            model=model,
+            content=content_types.to_content(content),
+            task_type=task_type,
+            title=title,
+            output_dimensionality=output_dimensionality,
         )
         embedding_response = await client.embed_content(
             embedding_request,
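
In practice, the new parameter and task-type aliases pass straight through `genai.embed_content`. A minimal usage sketch (the model name and the truncation expectation are assumptions for illustration; use an embedding model that supports dimensionality reduction):

import google.generativeai as genai

genai.configure(api_key="...")  # placeholder; assumes a valid API key

# Single input with a reduced embedding size and one of the new task types.
result = genai.embed_content(
    model="models/text-embedding-004",   # assumed model name for this example
    content="What is the airspeed of an unladen swallow?",
    task_type="question_answering",      # new alias mapped to EmbeddingTaskType.QUESTION_ANSWERING
    output_dimensionality=256,           # values beyond 256 are truncated from the end
)
print(len(result["embedding"]))          # expected 256 when the model honors the parameter

# An iterable of inputs returns one embedding per item.
batch = genai.embed_content(
    model="models/text-embedding-004",
    content=["claim one", "claim two"],
    task_type="fact_verification",       # also accepts "verification" or the integer 7
    output_dimensionality=256,
)
print(len(batch["embedding"]))           # 2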

tests/test_embedding.py

Lines changed: 34 additions & 1 deletion
@@ -122,9 +122,14 @@ def test_embed_content_title_and_task_2(self):
         text = "What are you?"
         with self.assertRaises(ValueError):
             embedding.embed_content(
-                model=DEFAULT_EMB_MODEL, content=text, task_type="unspecified", title="Exploring AI"
+                model=DEFAULT_EMB_MODEL, content=text, task_type="similarity", title="Exploring AI"
             )

+    def test_embed_content_with_negative_output_dimensionality(self):
+        text = "What are you?"
+        with self.assertRaises(ValueError):
+            embedding.embed_content(model=DEFAULT_EMB_MODEL, content=text, output_dimensionality=-1)
+
     def test_generate_answer_called_with_request_options(self):
         self.client.embed_content = mock.MagicMock()
         request = mock.ANY
@@ -174,6 +179,34 @@ def test_embed_content_called_with_request_options(self):

         self.client.embed_content.assert_called_once_with(request, **request_options)

+    @parameterized.named_parameters(
+        dict(
+            testcase_name="embedding.embed_content",
+            obj=embedding.embed_content,
+            aobj=embedding.embed_content_async,
+        ),
+    )
+    def test_async_code_match(self, obj, aobj):
+        import inspect
+        import re
+
+        source = inspect.getsource(obj)
+        asource = inspect.getsource(aobj)
+        source = re.sub('""".*"""', "", source, flags=re.DOTALL)
+        asource = re.sub('""".*"""', "", asource, flags=re.DOTALL)
+        asource = (
+            asource.replace("anext", "next")
+            .replace("aiter", "iter")
+            .replace("_async", "")
+            .replace("async ", "")
+            .replace("await ", "")
+            .replace("Async", "")
+            .replace("ASYNC_", "")
+        )
+
+        asource = re.sub(" *?# type: ignore", "", asource)
+        self.assertEqual(source, asource)
+

 if __name__ == "__main__":
     absltest.main()
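
The new test_async_code_match case keeps embed_content and embed_content_async in lockstep by comparing their source text after stripping docstrings and async-only syntax. A standalone sketch of the same idea, with illustrative function names that are not part of the SDK:

import inspect
import re


def normalized_source(fn) -> str:
    # Strip docstrings, async-only keywords, and "# type: ignore" comments,
    # mirroring the normalization done in test_async_code_match.
    src = inspect.getsource(fn)
    src = re.sub('""".*"""', "", src, flags=re.DOTALL)
    for needle in ("async ", "await ", "_async"):
        src = src.replace(needle, "")
    return re.sub(" *?# type: ignore", "", src)


def fetch(x):
    return x + 1


async def fetch_async(x):
    return x + 1


# The two definitions differ only in async syntax, so their normalized sources match.
assert normalized_source(fetch) == normalized_source(fetch_async)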

tests/test_embedding_async.py

Lines changed: 8 additions & 1 deletion
@@ -121,7 +121,14 @@ async def test_embed_content_async_title_and_task_2(self):
         text = "What are you?"
         with self.assertRaises(ValueError):
             await embedding.embed_content_async(
-                model=DEFAULT_EMB_MODEL, content=text, task_type="unspecified", title="Exploring AI"
+                model=DEFAULT_EMB_MODEL, content=text, task_type="similarity", title="Exploring AI"
+            )
+
+    async def test_embed_content_with_negative_output_dimensionality(self):
+        text = "What are you?"
+        with self.assertRaises(ValueError):
+            await embedding.embed_content_async(
+                model=DEFAULT_EMB_MODEL, content=text, output_dimensionality=-1
             )

     async def test_embed_content_called_with_request_options(self):
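
For completeness, a sketch of the async entry point with the new parameter; as above, the model name is an assumption for illustration, and genai.configure(...) is assumed to have been called already:

import asyncio

import google.generativeai as genai


async def main():
    result = await genai.embed_content_async(
        model="models/text-embedding-004",  # assumed model name
        content="What are you?",
        task_type="fact_verification",      # new alias added by this commit
        output_dimensionality=128,
    )
    print(len(result["embedding"]))


asyncio.run(main())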
