Skip to content

Commit 780b9e0

Browse files
Add main async code match test (#299)
* Remove f-string from doc * Add async_code_match * Remove async code match for separate modules * blacken * Remove looping for working_dirs * fpath.name is same as split * Move decorator checks to separate func * blacken * fix unbound param
1 parent c206dbb commit 780b9e0

File tree

5 files changed

+113
-207
lines changed

5 files changed

+113
-207
lines changed

google/generativeai/answer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,9 @@ def generate_answer(
249249
client: glm.GenerativeServiceClient | None = None,
250250
request_options: dict[str, Any] | None = None,
251251
):
252-
f"""
252+
"""
253253
Calls the GenerateAnswer API and returns a `types.Answer` containing the response.
254-
254+
255255
You can pass a literal list of text chunks:
256256
257257
>>> from google.generativeai import answer
@@ -320,6 +320,7 @@ async def generate_answer_async(
320320
safety_settings: safety_types.SafetySettingOptions | None = None,
321321
temperature: float | None = None,
322322
client: glm.GenerativeServiceClient | None = None,
323+
request_options: dict[str, Any] | None = None,
323324
):
324325
"""
325326
Calls the API and returns a `types.Answer` containing the answer.
@@ -341,6 +342,9 @@ async def generate_answer_async(
341342
Returns:
342343
A `types.Answer` containing the model's text answer response.
343344
"""
345+
if request_options is None:
346+
request_options = {}
347+
344348
if client is None:
345349
client = get_default_generative_async_client()
346350

@@ -354,6 +358,6 @@ async def generate_answer_async(
354358
answer_style=answer_style,
355359
)
356360

357-
response = await client.generate_answer(request)
361+
response = await client.generate_answer(request, **request_options)
358362

359363
return response

tests/test_async_code_match.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2024 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import pathlib
17+
import ast
18+
import typing
19+
import re
20+
21+
from absl.testing import absltest
22+
from absl.testing import parameterized
23+
24+
EXEMPT_DIRS = ["notebook"]
25+
EXEMPT_DECORATORS = ["overload", "property", "setter", "abstractmethod", "staticmethod"]
26+
EXEMPT_FILES = ["client.py", "version.py", "discuss.py", "files.py"]
27+
EXEMPT_FUNCTIONS = ["to_dict", "_to_proto", "to_proto", "from_proto", "from_dict", "_from_dict"]
28+
29+
30+
class CodeMatch(absltest.TestCase):
31+
32+
def _maybe_trim_docstring(self, node):
33+
if (
34+
node.body
35+
and isinstance(node.body[0], ast.Expr)
36+
and isinstance(node.body[0].value, ast.Constant)
37+
):
38+
node.body = node.body[1:]
39+
40+
return ast.unparse(node)
41+
42+
def _inspect_decorator_exemption(self, node, fpath) -> bool:
43+
for decorator in node.decorator_list:
44+
if isinstance(decorator, ast.Attribute):
45+
if decorator.attr in EXEMPT_DECORATORS:
46+
return True
47+
elif isinstance(decorator, ast.Name):
48+
if decorator.id in EXEMPT_DECORATORS:
49+
return True
50+
elif isinstance(decorator, ast.Call):
51+
if decorator.func.attr in EXEMPT_DECORATORS:
52+
return True
53+
else:
54+
raise TypeError(
55+
f"Unknown decorator type {decorator}, during checking {node.name} from {fpath.name}"
56+
)
57+
58+
return False
59+
60+
def _execute_code_match(self, source, asource):
61+
asource = (
62+
asource.replace("anext", "next")
63+
.replace("aiter", "iter")
64+
.replace("_async", "")
65+
.replace("async ", "")
66+
.replace("await ", "")
67+
.replace("Async", "")
68+
.replace("ASYNC_", "")
69+
)
70+
asource = re.sub(" *?# type: ignore", "", asource)
71+
self.assertEqual(source, asource)
72+
73+
def test_code_match_for_async_methods(self):
74+
for fpath in (pathlib.Path(__file__).parent.parent / "google").rglob("*.py"):
75+
if fpath.name in EXEMPT_FILES or any([d in fpath.parts for d in EXEMPT_DIRS]):
76+
continue
77+
# print(f"Checking {fpath.absolute()}")
78+
code_match_funcs: dict[str, ast.AST] = {}
79+
source = fpath.read_text()
80+
source_nodes = ast.parse(source)
81+
82+
for node in ast.walk(source_nodes):
83+
if isinstance(
84+
node, (ast.FunctionDef, ast.AsyncFunctionDef)
85+
) and not node.name.startswith("_"):
86+
name = node.name[:-6] if node.name.endswith("_async") else node.name
87+
if name in EXEMPT_FUNCTIONS or self._inspect_decorator_exemption(node, fpath):
88+
continue
89+
# print(f"Checking {node.name}")
90+
91+
if func_name := code_match_funcs.pop(name, None):
92+
snode, anode = (
93+
(func_name, node)
94+
if isinstance(node, ast.AsyncFunctionDef)
95+
else (node, func_name)
96+
)
97+
func_source = self._maybe_trim_docstring(snode)
98+
func_asource = self._maybe_trim_docstring(anode)
99+
self._execute_code_match(func_source, func_asource)
100+
# print(f"Matched {node.name}")
101+
else:
102+
code_match_funcs[node.name] = node
103+
104+
105+
if __name__ == "__main__":
106+
absltest.main()

tests/test_embedding.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -179,34 +179,6 @@ def test_embed_content_called_with_request_options(self):
179179

180180
self.client.embed_content.assert_called_once_with(request, **request_options)
181181

182-
@parameterized.named_parameters(
183-
dict(
184-
testcase_name="embedding.embed_content",
185-
obj=embedding.embed_content,
186-
aobj=embedding.embed_content_async,
187-
),
188-
)
189-
def test_async_code_match(self, obj, aobj):
190-
import inspect
191-
import re
192-
193-
source = inspect.getsource(obj)
194-
asource = inspect.getsource(aobj)
195-
source = re.sub('""".*"""', "", source, flags=re.DOTALL)
196-
asource = re.sub('""".*"""', "", asource, flags=re.DOTALL)
197-
asource = (
198-
asource.replace("anext", "next")
199-
.replace("aiter", "iter")
200-
.replace("_async", "")
201-
.replace("async ", "")
202-
.replace("await ", "")
203-
.replace("Async", "")
204-
.replace("ASYNC_", "")
205-
)
206-
207-
asource = re.sub(" *?# type: ignore", "", asource)
208-
self.assertEqual(source, asource)
209-
210182

211183
if __name__ == "__main__":
212184
absltest.main()

tests/test_permission.py

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -173,59 +173,6 @@ def test_update_permission_failure_restricted_update_path(self):
173173
{"grantee_type": permission_services.to_grantee_type("user")}
174174
)
175175

176-
@parameterized.named_parameters(
177-
[
178-
"create_permission",
179-
retriever_services.Corpus.create_permission,
180-
retriever_services.Corpus.create_permission_async,
181-
],
182-
[
183-
"list_permissions",
184-
retriever_services.Corpus.list_permissions,
185-
retriever_services.Corpus.list_permissions_async,
186-
],
187-
[
188-
"Permission.delete",
189-
permission_services.Permission.delete,
190-
permission_services.Permission.delete_async,
191-
],
192-
[
193-
"Permission.update",
194-
permission_services.Permission.update,
195-
permission_services.Permission.update_async,
196-
],
197-
[
198-
"Permission.get_permission",
199-
permission_services.Permission.get,
200-
permission_services.Permission.get_async,
201-
],
202-
[
203-
"permission.get_permission",
204-
permission.get_permission,
205-
permission.get_permission_async,
206-
],
207-
)
208-
def test_async_code_match(self, obj, aobj):
209-
import inspect
210-
import re
211-
212-
source = inspect.getsource(obj)
213-
asource = inspect.getsource(aobj)
214-
source = re.sub('""".*"""', "", source, flags=re.DOTALL)
215-
asource = re.sub('""".*"""', "", asource, flags=re.DOTALL)
216-
asource = (
217-
asource.replace("anext", "next")
218-
.replace("aiter", "iter")
219-
.replace("_async", "")
220-
.replace("async ", "")
221-
.replace("await ", "")
222-
.replace("Async", "")
223-
.replace("ASYNC_", "")
224-
)
225-
226-
asource = re.sub(" *?# type: ignore", "", asource)
227-
self.assertEqual(source, asource)
228-
229176
def test_create_corpus_called_with_request_options(self):
230177
self.client.create_corpus = unittest.mock.MagicMock()
231178
request = unittest.mock.ANY

tests/test_retriever.py

Lines changed: 0 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -647,129 +647,6 @@ def test_batch_delete_chunks(self):
647647
delete_request = demo_document.batch_delete_chunks(chunks=[x.name, y.name])
648648
self.assertIsInstance(self.observed_requests[-1], glm.BatchDeleteChunksRequest)
649649

650-
@parameterized.named_parameters(
651-
[
652-
"create_corpus",
653-
retriever.create_corpus,
654-
retriever.create_corpus_async,
655-
],
656-
[
657-
"get_corpus",
658-
retriever.get_corpus,
659-
retriever.get_corpus_async,
660-
],
661-
[
662-
"delete_corpus",
663-
retriever.delete_corpus,
664-
retriever.delete_corpus_async,
665-
],
666-
[
667-
"list_corpora",
668-
retriever.list_corpora,
669-
retriever.list_corpora_async,
670-
],
671-
[
672-
"Corpus.create_document",
673-
retriever_service.Corpus.create_document,
674-
retriever_service.Corpus.create_document_async,
675-
],
676-
[
677-
"Corpus.get_document",
678-
retriever_service.Corpus.get_document,
679-
retriever_service.Corpus.get_document_async,
680-
],
681-
[
682-
"Corpus.update",
683-
retriever_service.Corpus.update,
684-
retriever_service.Corpus.update_async,
685-
],
686-
[
687-
"Corpus.query",
688-
retriever_service.Corpus.query,
689-
retriever_service.Corpus.query_async,
690-
],
691-
[
692-
"Corpus.list_documents",
693-
retriever_service.Corpus.list_documents,
694-
retriever_service.Corpus.list_documents_async,
695-
],
696-
[
697-
"Corpus.delete_document",
698-
retriever_service.Corpus.delete_document,
699-
retriever_service.Corpus.delete_document_async,
700-
],
701-
[
702-
"Document.create_chunk",
703-
retriever_service.Document.create_chunk,
704-
retriever_service.Document.create_chunk_async,
705-
],
706-
[
707-
"Document.get_chunk",
708-
retriever_service.Document.get_chunk,
709-
retriever_service.Document.get_chunk_async,
710-
],
711-
[
712-
"Document.batch_create_chunks",
713-
retriever_service.Document.batch_create_chunks,
714-
retriever_service.Document.batch_create_chunks_async,
715-
],
716-
[
717-
"Document.list_chunks",
718-
retriever_service.Document.list_chunks,
719-
retriever_service.Document.list_chunks_async,
720-
],
721-
[
722-
"Document.query",
723-
retriever_service.Document.query,
724-
retriever_service.Document.query_async,
725-
],
726-
[
727-
"Document.update",
728-
retriever_service.Document.update,
729-
retriever_service.Document.update_async,
730-
],
731-
[
732-
"Document.batch_update_chunks",
733-
retriever_service.Document.batch_update_chunks,
734-
retriever_service.Document.batch_update_chunks_async,
735-
],
736-
[
737-
"Document.delete_chunk",
738-
retriever_service.Document.delete_chunk,
739-
retriever_service.Document.delete_chunk_async,
740-
],
741-
[
742-
"Document.batch_delete_chunks",
743-
retriever_service.Document.batch_delete_chunks,
744-
retriever_service.Document.batch_delete_chunks_async,
745-
],
746-
[
747-
"Chunk.update",
748-
retriever_service.Chunk.update,
749-
retriever_service.Chunk.update_async,
750-
],
751-
)
752-
def test_async_code_match(self, obj, aobj):
753-
import inspect
754-
import re
755-
756-
source = inspect.getsource(obj)
757-
asource = inspect.getsource(aobj)
758-
source = re.sub('""".*"""', "", source, flags=re.DOTALL)
759-
asource = re.sub('""".*"""', "", asource, flags=re.DOTALL)
760-
asource = (
761-
asource.replace("anext", "next")
762-
.replace("aiter", "iter")
763-
.replace("_async", "")
764-
.replace("async ", "")
765-
.replace("await ", "")
766-
.replace("Async", "")
767-
.replace("ASYNC_", "")
768-
)
769-
770-
asource = re.sub(" *?# type: ignore", "", asource)
771-
self.assertEqual(source, asource)
772-
773650
@parameterized.parameters(
774651
{"method": "create_corpus"},
775652
{"method": "get_corpus"},

0 commit comments

Comments
 (0)