Skip to content

Commit ccf8a52

Browse files
committed
fix: use file handle in file_from_path to prevent OOM
Additional fix for the _interactions layer - file_from_path was loading entire files with read_bytes() when it can return a file handle instead.
1 parent dfee3d3 commit ccf8a52

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

google/genai/_interactions/_utils/_utils.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ def _extract_items(
9292
if is_list(obj):
9393
files: list[tuple[str, FileTypes]] = []
9494
for entry in obj:
95-
assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "")
95+
assert_is_file_content(
96+
entry, key=flattened_key + "[]" if flattened_key else ""
97+
)
9698
files.append((flattened_key + "[]", cast(FileTypes, entry)))
9799
return files
98100

@@ -132,7 +134,9 @@ def _extract_items(
132134
item,
133135
path,
134136
index=index,
135-
flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
137+
flattened_key=flattened_key + "[]"
138+
if flattened_key is not None
139+
else "[]",
136140
)
137141
for item in obj
138142
]
@@ -282,7 +286,12 @@ def wrapper(*args: object, **kwargs: object) -> object:
282286
else: # no break
283287
if len(variants) > 1:
284288
variations = human_join(
285-
["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
289+
[
290+
"("
291+
+ human_join([quote(arg) for arg in variant], final="and")
292+
+ ")"
293+
for variant in variants
294+
]
286295
)
287296
msg = f"Missing required arguments; Expected either {variations} arguments to be given"
288297
else:
@@ -380,9 +389,8 @@ def removesuffix(string: str, suffix: str) -> str:
380389

381390

382391
def file_from_path(path: str) -> FileTypes:
383-
contents = Path(path).read_bytes()
384392
file_name = os.path.basename(path)
385-
return (file_name, contents)
393+
return (file_name, open(Path(path), "rb"))
386394

387395

388396
def get_required_header(headers: HeadersLike, header: str) -> str:
@@ -394,7 +402,11 @@ def get_required_header(headers: HeadersLike, header: str) -> str:
394402
return v
395403

396404
# to deal with the case where the header looks like Stainless-Event-Id
397-
intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())
405+
intercaps_header = re.sub(
406+
r"([^\w])(\w)",
407+
lambda pat: pat.group(1) + pat.group(2).upper(),
408+
header.capitalize(),
409+
)
398410

399411
for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
400412
value = headers.get(normalized_header)

0 commit comments

Comments
 (0)