diff --git a/google/generativeai/client.py b/google/generativeai/client.py index 938018b96..d2eb6b1c9 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -8,6 +8,7 @@ from typing import Any, cast from collections.abc import Sequence import httplib2 +from io import IOBase import google.ai.generativelanguage as glm import google.generativeai.protos as protos @@ -88,7 +89,7 @@ def _setup_discovery_api(self, metadata: dict | Sequence[tuple[str, str]] = ()): def create_file( self, - path: str | pathlib.Path | os.PathLike, + path: str | pathlib.Path | os.PathLike | IOBase, *, mime_type: str | None = None, name: str | None = None, @@ -105,9 +106,15 @@ def create_file( if display_name is not None: file["displayName"] = display_name - media = googleapiclient.http.MediaFileUpload( - filename=path, mimetype=mime_type, resumable=resumable - ) + if isinstance(path, IOBase): + media = googleapiclient.http.MediaIoBaseUpload( + fd=path, mimetype=mime_type, resumable=resumable + ) + else: + media = googleapiclient.http.MediaFileUpload( + filename=path, mimetype=mime_type, resumable=resumable + ) + request = self._discovery_api.media().upload(body={"file": file}, media_body=media) for key, value in metadata: request.headers[key] = value diff --git a/google/generativeai/files.py b/google/generativeai/files.py index c0d8e1e0a..b2581bdcd 100644 --- a/google/generativeai/files.py +++ b/google/generativeai/files.py @@ -21,6 +21,7 @@ import logging from google.generativeai import protos from itertools import islice +from io import IOBase from google.generativeai.types import file_types @@ -32,7 +33,7 @@ def upload_file( - path: str | pathlib.Path | os.PathLike, + path: str | pathlib.Path | os.PathLike | IOBase, *, mime_type: str | None = None, name: str | None = None, @@ -42,7 +43,7 @@ def upload_file( """Calls the API to upload a file using a supported file service. Args: - path: The path to the file to be uploaded. + path: The path to the file or a file-like object (e.g., BytesIO) to be uploaded. mime_type: The MIME type of the file. If not provided, it will be inferred from the file extension. name: The name of the file in the destination (e.g., 'files/sample-image'). @@ -57,17 +58,30 @@ def upload_file( """ client = get_default_file_client() - path = pathlib.Path(os.fspath(path)) + if isinstance(path, IOBase): + if mime_type is None: + raise ValueError( + "Unknown mime type: When passing a file like object to `path` (instead of a\n" + " path-like object) you must set the `mime_type` argument" + ) + else: + path = pathlib.Path(os.fspath(path)) - if mime_type is None: - mime_type, _ = mimetypes.guess_type(path) + if display_name is None: + display_name = path.name + + if mime_type is None: + mime_type, _ = mimetypes.guess_type(path) + + if mime_type is None: + raise ValueError( + "Unknown mime type: Could not determine the mimetype for your file\n" + " please set the `mime_type` argument" + ) if name is not None and "/" not in name: name = f"files/{name}" - if display_name is None: - display_name = path.name - response = client.create_file( path=path, mime_type=mime_type, name=name, display_name=display_name, resumable=resumable ) diff --git a/samples/files.py b/samples/files.py index cbed68a1e..8f98365aa 100644 --- a/samples/files.py +++ b/samples/files.py @@ -83,6 +83,18 @@ def test_files_create_pdf(self): print(response.text) # [END files_create_pdf] + def test_files_create_from_IO(self): + # [START files_create_io] + # You can pass a file-like object, instead of a path. + # Useful for streaming. + model = genai.GenerativeModel("gemini-1.5-flash") + fpath = media / "test.pdf" + with open(fpath, "rb") as f: + sample_pdf = genai.upload_file(f, mime_type="application/pdf") + response = model.generate_content(["Give me a summary of this pdf file.", sample_pdf]) + print(response.text) + # [END files_create_io] + def test_files_list(self): # [START files_list] print("My files:") diff --git a/tests/test_files.py b/tests/test_files.py index cb48316bd..0f7ca5707 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -18,6 +18,7 @@ import collections import datetime +import io import os from typing import Iterable, Sequence import pathlib @@ -38,7 +39,7 @@ def __init__(self, test): def create_file( self, - path: str | pathlib.Path | os.PathLike, + path: str | io.IOBase | os.PathLike, *, mime_type: str | None = None, name: str | None = None, @@ -102,12 +103,13 @@ def test_video_metadata(self): protos.File( uri="https://test", state="ACTIVE", + mime_type="video/quicktime", video_metadata=dict(video_duration=datetime.timedelta(seconds=30)), error=dict(code=7, message="ok?"), ) ) - f = genai.upload_file(path="dummy") + f = genai.upload_file(path="dummy.mov") self.assertEqual(google.rpc.status_pb2.Status(code=7, message="ok?"), f.error) self.assertEqual( protos.VideoMetadata(dict(video_duration=datetime.timedelta(seconds=30))),