Skip to content

Commit e680063

Browse files
GWeale authored and copybara-github committed
fix: Fixes a bug that causes intermittent pydantic validation errors when uploading files
The root cause is an unsafe in-memory mutation. The `SaveFilesAsArtifactsPlugin` was saving a direct reference to the message part and then modifying the message list in-place. This created a race condition where downstream code could alter the original part *after* it had been saved as an artifact, leading to a corrupted state. This CL saves a `copy.copy()` of the artifact, which creates a snapshot of the data. It also changes the plugin to return a new `types.Content` object instead of modifying the original message in-place.

PiperOrigin-RevId: 814308070
1 parent f667c74 commit e680063

File tree

5 files changed

+114
-34
lines changed

5 files changed

+114
-34
lines changed

src/google/adk/models/google_llm.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from __future__ import annotations
1717

1818
import contextlib
19+
import copy
1920
from functools import cached_property
2021
import logging
2122
import os
@@ -300,8 +301,13 @@ async def _preprocess_request(self, llm_request: LlmRequest) -> None:
300301
if not content.parts:
301302
continue
302303
for part in content.parts:
303-
_remove_display_name_if_present(part.inline_data)
304-
_remove_display_name_if_present(part.file_data)
304+
# Create copies to avoid mutating the original objects
305+
if part.inline_data:
306+
part.inline_data = copy.copy(part.inline_data)
307+
_remove_display_name_if_present(part.inline_data)
308+
if part.file_data:
309+
part.file_data = copy.copy(part.file_data)
310+
_remove_display_name_if_present(part.file_data)
305311

306312
# Initialize config if needed
307313
if llm_request.config and llm_request.config.tools:

src/google/adk/plugins/save_files_as_artifacts_plugin.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import copy
1718
import logging
1819
from typing import Optional
1920

@@ -29,14 +30,15 @@ class SaveFilesAsArtifactsPlugin(BasePlugin):
2930
"""A plugin that saves files embedded in user messages as artifacts.
3031
3132
This is useful to allow users to upload files in the chat experience and have
32-
those files available to the agent.
33-
34-
We use Blob.display_name to determine
35-
the file name. Artifacts with the same name will be overwritten. A placeholder
36-
with the artifact name will be put in place of the embedded file in the user
37-
message so the model knows where to find the file. You may want to add
38-
load_artifacts tool to the agent, or load the artifacts in your own tool to
39-
use the files.
33+
those files available to the agent within the current session.
34+
35+
We use Blob.display_name to determine the file name. By default, artifacts are
36+
session-scoped. For cross-session persistence, prefix the filename with
37+
"user:".
38+
Artifacts with the same name will be overwritten. A placeholder with the
39+
artifact name will be put in place of the embedded file in the user message
40+
so the model knows where to find the file. You may want to add load_artifacts
41+
tool to the agent, or load the artifacts in your own tool to use the files.
4042
"""
4143

4244
def __init__(self, name: str = 'save_files_as_artifacts_plugin'):
@@ -62,10 +64,14 @@ async def on_user_message_callback(
6264
return user_message
6365

6466
if not user_message.parts:
65-
return user_message
67+
return None
68+
69+
new_parts = []
70+
modified = False
6671

6772
for i, part in enumerate(user_message.parts):
6873
if part.inline_data is None:
74+
new_parts.append(part)
6975
continue
7076

7177
try:
@@ -77,23 +83,32 @@ async def on_user_message_callback(
7783
f'No display_name found, using generated filename: {file_name}'
7884
)
7985

86+
# Store original filename for display to user/ placeholder
87+
display_name = file_name
88+
89+
# Create a copy to stop mutation of the saved artifact if the original part is modified
8090
await invocation_context.artifact_service.save_artifact(
8191
app_name=invocation_context.app_name,
8292
user_id=invocation_context.user_id,
8393
session_id=invocation_context.session.id,
8494
filename=file_name,
85-
artifact=part,
95+
artifact=copy.copy(part),
8696
)
8797

88-
# Replace the inline data with a placeholder text
89-
user_message.parts[i] = types.Part(
90-
text=f'[Uploaded Artifact: "{file_name}"]'
98+
# Replace the inline data with a placeholder text (using the clean name)
99+
new_parts.append(
100+
types.Part(text=f'[Uploaded Artifact: "{display_name}"]')
91101
)
102+
modified = True
92103
logger.info(f'Successfully saved artifact: {file_name}')
93104

94105
except Exception as e:
95106
logger.error(f'Failed to save artifact for part {i}: {e}')
96107
# Keep the original part if saving fails
108+
new_parts.append(part)
97109
continue
98110

99-
return user_message
111+
if modified:
112+
return types.Content(role=user_message.role, parts=new_parts)
113+
else:
114+
return None

src/google/adk/tools/load_artifacts_tool.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import json
18+
import logging
1819
from typing import Any
1920
from typing import TYPE_CHECKING
2021

@@ -27,6 +28,8 @@
2728
from ..models.llm_request import LlmRequest
2829
from .tool_context import ToolContext
2930

31+
logger = logging.getLogger('google_adk.' + __name__)
32+
3033

3134
class LoadArtifactsTool(BaseTool):
3235
"""A tool that loads the artifacts and adds them to the session."""
@@ -96,7 +99,18 @@ async def _append_artifacts_to_llm_request(
9699
if function_response and function_response.name == 'load_artifacts':
97100
artifact_names = function_response.response['artifact_names']
98101
for artifact_name in artifact_names:
102+
# Try session-scoped first (default behavior)
99103
artifact = await tool_context.load_artifact(artifact_name)
104+
105+
# If not found and name doesn't already have user: prefix,
106+
# try cross-session artifacts with user: prefix
107+
if artifact is None and not artifact_name.startswith('user:'):
108+
prefixed_name = f'user:{artifact_name}'
109+
artifact = await tool_context.load_artifact(prefixed_name)
110+
111+
if artifact is None:
112+
logger.warning('Artifact "%s" not found, skipping', artifact_name)
113+
continue
100114
llm_request.contents.append(
101115
types.Content(
102116
role='user',

tests/unittests/artifacts/test_artifact_service.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,48 @@ async def test_list_versions(service_type):
277277
)
278278

279279
assert response_versions == list(range(4))
280+
281+
282+
@pytest.mark.asyncio
283+
async def test_list_keys_preserves_user_prefix():
284+
"""Tests that list_artifact_keys preserves 'user:' prefix in returned names."""
285+
artifact_service = InMemoryArtifactService()
286+
artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain")
287+
app_name = "app0"
288+
user_id = "user0"
289+
session_id = "123"
290+
291+
# Save artifacts with "user:" prefix (cross-session artifacts)
292+
await artifact_service.save_artifact(
293+
app_name=app_name,
294+
user_id=user_id,
295+
session_id=session_id,
296+
filename="user:document.pdf",
297+
artifact=artifact,
298+
)
299+
300+
await artifact_service.save_artifact(
301+
app_name=app_name,
302+
user_id=user_id,
303+
session_id=session_id,
304+
filename="user:image.png",
305+
artifact=artifact,
306+
)
307+
308+
# Save session-scoped artifact without prefix
309+
await artifact_service.save_artifact(
310+
app_name=app_name,
311+
user_id=user_id,
312+
session_id=session_id,
313+
filename="session_file.txt",
314+
artifact=artifact,
315+
)
316+
317+
# List artifacts should return names with "user:" prefix for user-scoped artifacts
318+
artifact_keys = await artifact_service.list_artifact_keys(
319+
app_name=app_name, user_id=user_id, session_id=session_id
320+
)
321+
322+
# Should contain prefixed names and session file
323+
expected_keys = ["user:document.pdf", "user:image.png", "session_file.txt"]
324+
assert sorted(artifact_keys) == sorted(expected_keys)

tests/unittests/plugins/test_save_files_as_artifacts.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ async def test_save_files_with_display_name(self):
5757
invocation_context=self.mock_context, user_message=user_message
5858
)
5959

60-
# Verify artifact was saved with correct filename
60+
# Verify artifact was saved with correct filename (session-scoped by default)
6161
self.mock_context.artifact_service.save_artifact.assert_called_once_with(
6262
app_name="test_app",
6363
user_id="test_user",
@@ -66,7 +66,7 @@ async def test_save_files_with_display_name(self):
6666
artifact=original_part,
6767
)
6868

69-
# Verify message was modified with placeholder
69+
# Verify message was modified with placeholder (clean name)
7070
assert result.parts[0].text == '[Uploaded Artifact: "test_document.pdf"]'
7171

7272
@pytest.mark.asyncio
@@ -85,7 +85,7 @@ async def test_save_files_without_display_name(self):
8585
invocation_context=self.mock_context, user_message=user_message
8686
)
8787

88-
# Verify artifact was saved with generated filename
88+
# Verify artifact was saved with generated filename (session-scoped by default)
8989
expected_filename = "artifact_test_invocation_123_0"
9090
self.mock_context.artifact_service.save_artifact.assert_called_once_with(
9191
app_name="test_app",
@@ -95,8 +95,12 @@ async def test_save_files_without_display_name(self):
9595
artifact=original_part,
9696
)
9797

98-
# Verify message was modified with generated filename
99-
assert result.parts[0].text == f'[Uploaded Artifact: "{expected_filename}"]'
98+
# Verify message was modified with generated filename (clean name)
99+
generated_display_name = "artifact_test_invocation_123_0"
100+
assert (
101+
result.parts[0].text
102+
== f'[Uploaded Artifact: "{generated_display_name}"]'
103+
)
100104

101105
@pytest.mark.asyncio
102106
async def test_multiple_files_in_message(self):
@@ -138,7 +142,7 @@ async def test_multiple_files_in_message(self):
138142
)
139143
assert second_call[1]["filename"] == "file2.jpg"
140144

141-
# Verify message parts were modified correctly
145+
# Verify message parts were modified correctly (clean names)
142146
assert result.parts[0].text == '[Uploaded Artifact: "file1.txt"]'
143147
assert result.parts[1].text == "Some text between files" # Unchanged
144148
assert result.parts[2].text == '[Uploaded Artifact: "file2.jpg"]'
@@ -174,9 +178,8 @@ async def test_no_parts_in_message(self):
174178
invocation_context=self.mock_context, user_message=user_message
175179
)
176180

177-
# Should return original message unchanged
178-
assert result == user_message
179-
assert result.parts == []
181+
# Should return None to proceed with original message
182+
assert result is None
180183

181184
# Should not try to save any artifacts
182185
self.mock_context.artifact_service.save_artifact.assert_not_called()
@@ -193,10 +196,8 @@ async def test_parts_without_inline_data(self):
193196
invocation_context=self.mock_context, user_message=user_message
194197
)
195198

196-
# Should return original message unchanged
197-
assert result == user_message
198-
assert result.parts[0].text == "Hello world"
199-
assert result.parts[1].text == "No files here"
199+
# Should return None to proceed with original message
200+
assert result is None
200201

201202
# Should not try to save any artifacts
202203
self.mock_context.artifact_service.save_artifact.assert_not_called()
@@ -221,9 +222,8 @@ async def test_save_artifact_failure(self):
221222
invocation_context=self.mock_context, user_message=user_message
222223
)
223224

224-
# Should preserve original part when saving fails
225-
assert result.parts[0] == original_part
226-
assert result.parts[0].inline_data == inline_data
225+
# Should return None when saving fails (no modifications made)
226+
assert result is None
227227

228228
@pytest.mark.asyncio
229229
async def test_mixed_success_and_failure(self):
@@ -264,7 +264,7 @@ def mock_save_artifact(*_args, **_kwargs):
264264
invocation_context=self.mock_context, user_message=user_message
265265
)
266266

267-
# First file should be replaced with placeholder
267+
# First file should be replaced with placeholder (clean name)
268268
assert result.parts[0].text == '[Uploaded Artifact: "success.pdf"]'
269269

270270
# Second file should remain unchanged due to failure
@@ -287,7 +287,7 @@ async def test_placeholder_text_format(self):
287287
invocation_context=self.mock_context, user_message=user_message
288288
)
289289

290-
# Verify exact format of placeholder text
290+
# Verify exact format of placeholder text (clean name)
291291
expected_text = '[Uploaded Artifact: "test file with spaces.docx"]'
292292
assert result.parts[0].text == expected_text
293293

0 commit comments

Comments
 (0)