Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions google/generativeai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def create_file(
file["name"] = name
if display_name is not None:
file["displayName"] = display_name

if isinstance(path, IOBase):
media = googleapiclient.http.MediaIoBaseUpload(
fd=path, mimetype=mime_type, resumable=resumable
Expand All @@ -114,7 +114,7 @@ def create_file(
media = googleapiclient.http.MediaFileUpload(
filename=path, mimetype=mime_type, resumable=resumable
)

request = self._discovery_api.media().upload(body={"file": file}, media_body=media)
for key, value in metadata:
request.headers[key] = value
Expand Down
17 changes: 13 additions & 4 deletions google/generativeai/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,13 @@ def upload_file(
"""
client = get_default_file_client()

if not isinstance(path, IOBase):
if isinstance(path, IOBase):
if mime_type is None:
raise ValueError(
"Unknown mime type: When passing a file like object to `path` (instead of a\n"
" path-like object) you must set the `mime_type` argument"
)
else:
path = pathlib.Path(os.fspath(path))

if display_name is None:
Expand All @@ -67,9 +73,12 @@ def upload_file(
if mime_type is None:
mime_type, _ = mimetypes.guess_type(path)

if mime_type is None:
# Guess failed or IOBase, use octet-stream.
mime_type = 'application/octet-stream'
if mime_type is None:
if mime_type is None:
raise ValueError(
"Unknown mime type: Could not determine the mimetype for your file\n"
" please set the `mime_type` argument"
)

if name is not None and "/" not in name:
name = f"files/{name}"
Expand Down
3 changes: 2 additions & 1 deletion tests/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,13 @@ def test_video_metadata(self):
protos.File(
uri="https://test",
state="ACTIVE",
mime_type="video/quicktime",
video_metadata=dict(video_duration=datetime.timedelta(seconds=30)),
error=dict(code=7, message="ok?"),
)
)

f = genai.upload_file(path="dummy")
f = genai.upload_file(path="dummy.mov")
self.assertEqual(google.rpc.status_pb2.Status(code=7, message="ok?"), f.error)
self.assertEqual(
protos.VideoMetadata(dict(video_duration=datetime.timedelta(seconds=30))),
Expand Down
Loading