Skip to content

Commit 3c3c661

Browse files
Add sample test case for asyncio file uploads
1 parent 668c520 commit 3c3c661

File tree

3 files changed

+28
-87
lines changed

3 files changed

+28
-87
lines changed

google/generativeai/files.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828

2929
from google.generativeai.client import get_default_file_client
3030

31-
__all__ = ["upload_file", "get_file", "list_files", "delete_file"]
31+
__all__ = ["upload_file", "get_file", "list_files", "delete_file",
32+
"upload_file_async", "get_file_async", "list_files_async", "delete_file_async"]
3233

3334
mimetypes.add_type("image/webp", ".webp")
3435

samples/files.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import unittest
1516
from absl.testing import absltest
1617

1718
import google
1819
import google.generativeai as genai
1920
import pathlib
21+
import tempfile
22+
import asyncio
2023

2124
media = pathlib.Path(__file__).parents[1] / "third_party"
2225

@@ -127,5 +130,28 @@ def test_files_delete(self):
127130
# [END files_delete]
128131

129132

133+
class AsyncTests(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
134+
async def test_upload_file_async(self):
135+
import google.generativeai.files as files
136+
tempdir = pathlib.Path(tempfile.mkdtemp())
137+
results = []
138+
async def create_and_upload_file(n: int):
139+
fname = tempdir / str(n)
140+
fname.write_text(str(n))
141+
file_obj = await files.upload_file_async(fname, mime_type='text/plain')
142+
results.append(file_obj)
143+
144+
tasks = []
145+
for n in range(5):
146+
tasks.append(asyncio.create_task(create_and_upload_file(n)))
147+
148+
for task in tasks:
149+
await task
150+
151+
self.assertLen(results, 5)
152+
self.assertEqual(sorted(int(f.display_name) for f in results),
153+
list(range(5)))
154+
155+
130156
if __name__ == "__main__":
131157
absltest.main()

tests/test_files_async.py

Lines changed: 0 additions & 86 deletions
This file was deleted.

0 commit comments

Comments
 (0)