Skip to content

Commit dfee3d3

Browse files
committed
fix: prevent OOM by using file handles and chunked base64 encoding
This change addresses memory spikes that can cause OOM errors when uploading large files to the Gemini File API. Changes: 1. _interactions/_files.py: Return open file handles instead of loading entire files into memory with read_bytes(). httpx supports IO[bytes] directly, so there's no need to pre-load file contents. 2. _interactions/_utils/_transform.py: Implement chunked base64 encoding using 3MB chunks (must be multiple of 3 for base64 correctness) to reduce peak memory usage when encoding files for inline data. The existing chunked upload mechanism in _api_client.py (8MB chunks) was already correct, but files were being loaded into memory before reaching that code path. This fix ensures memory-efficient handling from the start of the upload flow.
1 parent 351e490 commit dfee3d3

File tree

2 files changed

+148
-54
lines changed

2 files changed

+148
-54
lines changed

google/genai/_interactions/_files.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,20 @@ def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
4242

4343
def is_file_content(obj: object) -> TypeGuard[FileContent]:
    """Return ``True`` if *obj* can be used directly as raw file content.

    Accepted forms: raw ``bytes``, a ``(filename, content, ...)`` tuple,
    an open ``io.IOBase`` stream, or an ``os.PathLike`` path.
    """
    return isinstance(obj, (bytes, tuple, io.IOBase, os.PathLike))
4750

4851

4952
def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
    """Raise ``RuntimeError`` if *obj* is not valid file content.

    Args:
        obj: Candidate value to validate.
        key: Optional mapping key used to point at the offending entry in
            the error message; when absent the value itself is shown.
    """
    if is_file_content(obj):
        return
    if key is not None:
        prefix = f"Expected entry at `{key}`"
    else:
        prefix = f"Expected file input `{obj!r}`"
    # `from None` suppresses any in-flight exception context so callers see
    # a single clean error.
    raise RuntimeError(
        f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead."
    ) from None
@@ -71,7 +78,9 @@ def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
7178
elif is_sequence_t(files):
7279
files = [(key, _transform_file(file)) for key, file in files]
7380
else:
74-
raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")
81+
raise TypeError(
82+
f"Unexpected file type input {type(files)}, expected mapping or sequence"
83+
)
7584

7685
return files
7786

@@ -80,19 +89,23 @@ def _transform_file(file: FileTypes) -> HttpxFileTypes:
8089
if is_file_content(file):
8190
if isinstance(file, os.PathLike):
8291
path = pathlib.Path(file)
83-
return (path.name, path.read_bytes())
92+
# Return an open file handle instead of loading entire file into memory.
93+
# This prevents OOM errors for large files. httpx supports IO[bytes] directly.
94+
return (path.name, open(path, "rb"))
8495

8596
return file
8697

8798
if is_tuple_t(file):
8899
return (file[0], read_file_content(file[1]), *file[2:])
89100

90-
raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
101+
raise TypeError(
102+
f"Expected file types input to be a FileContent type or to be a tuple"
103+
)
91104

92105

93106
def read_file_content(file: FileContent) -> HttpxFileContent:
    """Resolve *file* into something httpx can stream.

    Path-like values are opened as binary handles (the consumer owns
    closing them); every other accepted form is passed through untouched.
    """
    if not isinstance(file, os.PathLike):
        return file
    # Hand back an open handle rather than the file's bytes so large
    # uploads are streamed instead of held fully in memory.
    return open(pathlib.Path(file), "rb")
97110

98111

@@ -113,27 +126,31 @@ async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles
113126
elif is_sequence_t(files):
114127
files = [(key, await _async_transform_file(file)) for key, file in files]
115128
else:
116-
raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence")
129+
raise TypeError(
130+
"Unexpected file type input {type(files)}, expected mapping or sequence"
131+
)
117132

118133
return files
119134

120135

121136
async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
    """Normalise *file* into a form httpx accepts for multipart uploads.

    Path-like inputs become ``(filename, handle)`` pairs holding an open
    binary handle, so large files are streamed rather than loaded into
    memory (httpx accepts ``IO[bytes]`` and reads it during the request).
    Tuples get their content element resolved via
    ``async_read_file_content``; other valid file content passes through
    unchanged.

    Raises:
        TypeError: If *file* is neither file content nor a tuple.
    """
    if is_file_content(file):
        if isinstance(file, os.PathLike):
            path = pathlib.Path(file)
            # Return an open file handle instead of loading the entire file
            # into memory; this keeps peak memory flat for large uploads.
            # NOTE(review): nothing here closes the handle — presumably the
            # httpx transport consumes and releases it; confirm upstream.
            return (path.name, open(path, "rb"))

        return file

    if is_tuple_t(file):
        return (file[0], await async_read_file_content(file[1]), *file[2:])

    # Fix: the original used an f-string with no placeholders (ruff F541);
    # the message text itself is unchanged.
    raise TypeError(
        "Expected file types input to be a FileContent type or to be a tuple"
    )
133150

134151

135152
async def async_read_file_content(file: FileContent) -> HttpxFileContent:
    """Async counterpart of ``read_file_content``.

    Opens path-like inputs as binary handles (the consumer closes them);
    non-path content is returned untouched. Nothing here awaits — the
    signature stays async for API symmetry with the sync variant.
    """
    if not isinstance(file, os.PathLike):
        return file
    return open(pathlib.Path(file), "rb")

google/genai/_interactions/_utils/_transform.py

Lines changed: 120 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020
import pathlib
2121
from typing import Any, Mapping, TypeVar, cast
2222
from datetime import date, datetime
23-
from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints
23+
from typing_extensions import (
24+
Literal,
25+
get_args,
26+
override,
27+
get_type_hints as _get_type_hints,
28+
)
2429

2530
import anyio
2631
import pydantic
@@ -196,15 +201,26 @@ def _transform_recursive(
196201

197202
if origin == dict and is_mapping(data):
198203
items_type = get_args(stripped_type)[1]
199-
return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
204+
return {
205+
key: _transform_recursive(value, annotation=items_type)
206+
for key, value in data.items()
207+
}
200208

201209
if (
202210
# List[T]
203211
(is_list_type(stripped_type) and is_list(data))
204212
# Iterable[T]
205-
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
213+
or (
214+
is_iterable_type(stripped_type)
215+
and is_iterable(data)
216+
and not isinstance(data, str)
217+
)
206218
# Sequence[T]
207-
or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
219+
or (
220+
is_sequence_type(stripped_type)
221+
and is_sequence(data)
222+
and not isinstance(data, str)
223+
)
208224
):
209225
# dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
210226
# intended as an iterable, so we don't transform it.
@@ -221,7 +237,10 @@ def _transform_recursive(
221237
return data
222238
return list(data)
223239

224-
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
240+
return [
241+
_transform_recursive(d, annotation=annotation, inner_type=inner_type)
242+
for d in data
243+
]
225244

226245
if is_union_type(stripped_type):
227246
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
@@ -248,7 +267,9 @@ def _transform_recursive(
248267
return data
249268

250269

251-
def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
270+
def _format_data(
271+
data: object, format_: PropertyFormat, format_template: str | None
272+
) -> object:
252273
if isinstance(data, (date, datetime)):
253274
if format_ == "iso8601":
254275
return data.isoformat()
@@ -257,22 +278,35 @@ def _format_data(data: object, format_: PropertyFormat, format_template: str | N
257278
return data.strftime(format_template)
258279

259280
if format_ == "base64" and is_base64_file_input(data):
260-
binary: str | bytes | None = None
261-
262-
if isinstance(data, pathlib.Path):
263-
binary = data.read_bytes()
264-
elif isinstance(data, io.IOBase):
265-
binary = data.read()
281+
return _encode_file_to_base64(data)
266282

267-
if isinstance(binary, str): # type: ignore[unreachable]
268-
binary = binary.encode()
269-
270-
if not isinstance(binary, bytes):
271-
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
283+
return data
272284

273-
return base64.b64encode(binary).decode("ascii")
274285

275-
return data
286+
def _encode_file_to_base64(data: object) -> str:
287+
"""Encode file content to base64 using chunked reading to reduce peak memory usage."""
288+
CHUNK_SIZE = 3 * 1024 * 1024 # 3MB (must be multiple of 3 for base64)
289+
chunks: list[str] = []
290+
291+
if isinstance(data, pathlib.Path):
292+
with open(data, "rb") as f:
293+
while True:
294+
chunk = f.read(CHUNK_SIZE)
295+
if not chunk:
296+
break
297+
chunks.append(base64.b64encode(chunk).decode("ascii"))
298+
elif isinstance(data, io.IOBase):
299+
while True:
300+
chunk = data.read(CHUNK_SIZE)
301+
if not chunk:
302+
break
303+
if isinstance(chunk, str):
304+
chunk = chunk.encode()
305+
chunks.append(base64.b64encode(chunk).decode("ascii"))
306+
else:
307+
raise RuntimeError(f"Could not read bytes from {data}; Received {type(data)}")
308+
309+
return "".join(chunks)
276310

277311

278312
def _transform_typeddict(
@@ -292,7 +326,9 @@ def _transform_typeddict(
292326
# we do not have a type annotation for this field, leave it as is
293327
result[key] = value
294328
else:
295-
result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
329+
result[_maybe_transform_key(key, type_)] = _transform_recursive(
330+
value, annotation=type_
331+
)
296332
return result
297333

298334

@@ -328,7 +364,9 @@ class Params(TypedDict, total=False):
328364
329365
It should be noted that the transformations that this function does are not represented in the type system.
330366
"""
331-
transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
367+
transformed = await _async_transform_recursive(
368+
data, annotation=cast(type, expected_type)
369+
)
332370
return cast(_T, transformed)
333371

334372

@@ -362,15 +400,26 @@ async def _async_transform_recursive(
362400

363401
if origin == dict and is_mapping(data):
364402
items_type = get_args(stripped_type)[1]
365-
return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
403+
return {
404+
key: _transform_recursive(value, annotation=items_type)
405+
for key, value in data.items()
406+
}
366407

367408
if (
368409
# List[T]
369410
(is_list_type(stripped_type) and is_list(data))
370411
# Iterable[T]
371-
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
412+
or (
413+
is_iterable_type(stripped_type)
414+
and is_iterable(data)
415+
and not isinstance(data, str)
416+
)
372417
# Sequence[T]
373-
or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
418+
or (
419+
is_sequence_type(stripped_type)
420+
and is_sequence(data)
421+
and not isinstance(data, str)
422+
)
374423
):
375424
# dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
376425
# intended as an iterable, so we don't transform it.
@@ -387,15 +436,22 @@ async def _async_transform_recursive(
387436
return data
388437
return list(data)
389438

390-
return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
439+
return [
440+
await _async_transform_recursive(
441+
d, annotation=annotation, inner_type=inner_type
442+
)
443+
for d in data
444+
]
391445

392446
if is_union_type(stripped_type):
393447
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
394448
#
395449
# TODO: there may be edge cases where the same normalized field name will transform to two different names
396450
# in different subtypes.
397451
for subtype in get_args(stripped_type):
398-
data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
452+
data = await _async_transform_recursive(
453+
data, annotation=annotation, inner_type=subtype
454+
)
399455
return data
400456

401457
if isinstance(data, pydantic.BaseModel):
@@ -409,12 +465,16 @@ async def _async_transform_recursive(
409465
annotations = get_args(annotated_type)[1:]
410466
for annotation in annotations:
411467
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
412-
return await _async_format_data(data, annotation.format, annotation.format_template)
468+
return await _async_format_data(
469+
data, annotation.format, annotation.format_template
470+
)
413471

414472
return data
415473

416474

417-
async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
475+
async def _async_format_data(
476+
data: object, format_: PropertyFormat, format_template: str | None
477+
) -> object:
418478
if isinstance(data, (date, datetime)):
419479
if format_ == "iso8601":
420480
return data.isoformat()
@@ -423,22 +483,35 @@ async def _async_format_data(data: object, format_: PropertyFormat, format_templ
423483
return data.strftime(format_template)
424484

425485
if format_ == "base64" and is_base64_file_input(data):
426-
binary: str | bytes | None = None
427-
428-
if isinstance(data, pathlib.Path):
429-
binary = await anyio.Path(data).read_bytes()
430-
elif isinstance(data, io.IOBase):
431-
binary = data.read()
432-
433-
if isinstance(binary, str): # type: ignore[unreachable]
434-
binary = binary.encode()
486+
return await _async_encode_file_to_base64(data)
435487

436-
if not isinstance(binary, bytes):
437-
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
488+
return data
438489

439-
return base64.b64encode(binary).decode("ascii")
440490

441-
return data
491+
async def _async_encode_file_to_base64(data: object) -> str:
492+
"""Encode file content to base64 using chunked reading to reduce peak memory usage."""
493+
CHUNK_SIZE = 3 * 1024 * 1024 # 3MB (must be multiple of 3 for base64)
494+
chunks: list[str] = []
495+
496+
if isinstance(data, pathlib.Path):
497+
async with await anyio.Path(data).open("rb") as f:
498+
while True:
499+
chunk = await f.read(CHUNK_SIZE)
500+
if not chunk:
501+
break
502+
chunks.append(base64.b64encode(chunk).decode("ascii"))
503+
elif isinstance(data, io.IOBase):
504+
while True:
505+
chunk = data.read(CHUNK_SIZE)
506+
if not chunk:
507+
break
508+
if isinstance(chunk, str):
509+
chunk = chunk.encode()
510+
chunks.append(base64.b64encode(chunk).decode("ascii"))
511+
else:
512+
raise RuntimeError(f"Could not read bytes from {data}; Received {type(data)}")
513+
514+
return "".join(chunks)
442515

443516

444517
async def _async_transform_typeddict(
@@ -458,7 +531,9 @@ async def _async_transform_typeddict(
458531
# we do not have a type annotation for this field, leave it as is
459532
result[key] = value
460533
else:
461-
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
534+
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(
535+
value, annotation=type_
536+
)
462537
return result
463538

464539

@@ -469,4 +544,6 @@ def get_type_hints(
469544
localns: Mapping[str, Any] | None = None,
470545
include_extras: bool = False,
471546
) -> dict[str, Any]:
472-
return _get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)
547+
return _get_type_hints(
548+
obj, globalns=globalns, localns=localns, include_extras=include_extras
549+
)

0 commit comments

Comments
 (0)