Skip to content

Commit a50cbf6

Browse files
Address PR review comments for tool result uniqueness
- Fix type safety: Use list[TextContent | ImageContent] instead of list[TextContent] with type ignore - Use model_copy() instead of model_dump() for observation merging - Add UUID suffix to merged event IDs to prevent potential collisions - Extract _group_observations_by_tool_call() helper to eliminate duplicate iteration between transform() and enforce() - Replace unittest mocks with real ObservationEvent instances in tests Co-authored-by: openhands <openhands@all-hands.dev>
1 parent a48436a commit a50cbf6

File tree

2 files changed

+48
-74
lines changed

2 files changed

+48
-74
lines changed

openhands-sdk/openhands/sdk/context/view/properties/tool_result_uniqueness.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
information for the LLM.
1111
"""
1212

13+
import uuid
1314
from collections import defaultdict
1415
from collections.abc import Sequence
1516

@@ -24,7 +25,7 @@
2425
ObservationEvent,
2526
ToolCallID,
2627
)
27-
from openhands.sdk.llm import TextContent
28+
from openhands.sdk.llm import ImageContent, TextContent
2829

2930

3031
def _create_merged_observation(
@@ -49,21 +50,18 @@ def _create_merged_observation(
4950

5051
# Create new content list with error context prepended
5152
original_content = list(obs_event.observation.content)
52-
merged_content: list[TextContent] = [TextContent(text=error_prefix)]
53-
merged_content.extend(original_content) # type: ignore[arg-type]
53+
merged_content: list[TextContent | ImageContent] = [TextContent(text=error_prefix)]
54+
merged_content.extend(original_content)
5455

55-
# Create a new observation with merged content
56-
# We need to preserve all fields from the original observation
57-
obs_data = obs_event.observation.model_dump()
58-
obs_data["content"] = merged_content
59-
60-
# Create the new observation using the same class as the original
61-
merged_observation = obs_event.observation.__class__(**obs_data)
56+
# Create a new observation with merged content using model_copy
57+
merged_observation = obs_event.observation.model_copy(
58+
update={"content": merged_content}
59+
)
6260

6361
# Create new ObservationEvent with a unique ID
64-
# ID format: "{original_id}-merged" to ensure uniqueness
62+
# ID format: "{original_id}-merged-{uuid}" to ensure uniqueness
6563
return ObservationEvent(
66-
id=f"{obs_event.id}-merged",
64+
id=f"{obs_event.id}-merged-{uuid.uuid4().hex[:8]}",
6765
tool_name=obs_event.tool_name,
6866
tool_call_id=obs_event.tool_call_id,
6967
observation=merged_observation,
@@ -72,6 +70,26 @@ def _create_merged_observation(
7270
)
7371

7472

73+
def _group_observations_by_tool_call(
74+
events: list[LLMConvertibleEvent],
75+
) -> dict[ToolCallID, list[ObservationBaseEvent]]:
76+
"""Group observations by their tool_call_id.
77+
78+
Args:
79+
events: The list of events to process.
80+
81+
Returns:
82+
A mapping from tool_call_id to list of observations with that ID.
83+
"""
84+
observations_by_tool_call: dict[ToolCallID, list[ObservationBaseEvent]] = (
85+
defaultdict(list)
86+
)
87+
for event in events:
88+
if isinstance(event, ObservationBaseEvent):
89+
observations_by_tool_call[event.tool_call_id].append(event)
90+
return observations_by_tool_call
91+
92+
7593
class ToolResultUniquenessProperty(ViewPropertyBase):
7694
"""Each tool_call_id must have exactly one tool result.
7795
@@ -92,15 +110,9 @@ def transform(
92110
(typically from a restart scenario), merge the error context into the
93111
observation so the LLM has full context about what happened.
94112
"""
95-
# Group observations by tool_call_id
96-
observations_by_tool_call: dict[ToolCallID, list[ObservationBaseEvent]] = (
97-
defaultdict(list)
113+
observations_by_tool_call = _group_observations_by_tool_call(
114+
current_view_events
98115
)
99-
100-
for event in current_view_events:
101-
if isinstance(event, ObservationBaseEvent):
102-
observations_by_tool_call[event.tool_call_id].append(event)
103-
104116
transforms: dict[EventID, LLMConvertibleEvent] = {}
105117

106118
for observations in observations_by_tool_call.values():
@@ -132,15 +144,9 @@ def enforce(
132144
removes the remaining duplicate events (the original AgentErrorEvents and
133145
any other duplicates).
134146
"""
135-
# Group observations by tool_call_id
136-
observations_by_tool_call: dict[ToolCallID, list[ObservationBaseEvent]] = (
137-
defaultdict(list)
147+
observations_by_tool_call = _group_observations_by_tool_call(
148+
current_view_events
138149
)
139-
140-
for event in current_view_events:
141-
if isinstance(event, ObservationBaseEvent):
142-
observations_by_tool_call[event.tool_call_id].append(event)
143-
144150
events_to_remove: set[EventID] = set()
145151

146152
for observations in observations_by_tool_call.values():

tests/sdk/context/view/properties/test_tool_result_uniqueness.py

Lines changed: 14 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,12 @@
66
- Duplicates are removed after merging
77
"""
88

9-
from unittest.mock import create_autospec
10-
119
from openhands.sdk.context.view.properties.tool_result_uniqueness import (
1210
ToolResultUniquenessProperty,
1311
_create_merged_observation,
1412
)
1513
from openhands.sdk.event.base import LLMConvertibleEvent
1614
from openhands.sdk.event.llm_convertible import (
17-
ActionEvent,
1815
AgentErrorEvent,
1916
ObservationEvent,
2017
UserRejectObservation,
@@ -63,7 +60,8 @@ def test_merges_single_error_into_observation(self) -> None:
6360

6461
merged = _create_merged_observation(obs, [error])
6562

66-
assert merged.id == "obs_1-merged"
63+
# ID format: {original_id}-merged-{uuid}
64+
assert merged.id.startswith("obs_1-merged-")
6765
assert merged.tool_call_id == "call_1"
6866
assert merged.tool_name == "terminal"
6967
# Check that error context is prepended
@@ -129,7 +127,8 @@ def test_transforms_when_error_and_observation_exist(self) -> None:
129127
assert "obs_1" in transforms
130128
merged = transforms["obs_1"]
131129
assert isinstance(merged, ObservationEvent)
132-
assert merged.id == "obs_1-merged"
130+
# ID format: {original_id}-merged-{uuid}
131+
assert merged.id.startswith("obs_1-merged-")
133132
# Error content should be in merged observation
134133
content_text = "".join(
135134
c.text
@@ -173,13 +172,8 @@ def test_empty_list(self) -> None:
173172

174173
def test_no_duplicates(self) -> None:
175174
"""Test enforce when there are no duplicate tool_call_ids."""
176-
obs1 = create_autospec(ObservationEvent, instance=True)
177-
obs1.tool_call_id = "call_1"
178-
obs1.id = "obs_1"
179-
180-
obs2 = create_autospec(ObservationEvent, instance=True)
181-
obs2.tool_call_id = "call_2"
182-
obs2.id = "obs_2"
175+
obs1 = create_observation_event("obs_1", "call_1", "Result 1")
176+
obs2 = create_observation_event("obs_2", "call_2", "Result 2")
183177

184178
events: list[LLMConvertibleEvent] = [
185179
message_event("Start"),
@@ -193,13 +187,8 @@ def test_no_duplicates(self) -> None:
193187

194188
def test_duplicate_observation_events(self) -> None:
195189
"""Test that duplicate ObservationEvents keep the later one."""
196-
obs1 = create_autospec(ObservationEvent, instance=True)
197-
obs1.tool_call_id = "call_1"
198-
obs1.id = "obs_1"
199-
200-
obs2 = create_autospec(ObservationEvent, instance=True)
201-
obs2.tool_call_id = "call_1" # Same tool_call_id!
202-
obs2.id = "obs_2"
190+
obs1 = create_observation_event("obs_1", "call_1", "First result")
191+
obs2 = create_observation_event("obs_2", "call_1", "Second result")
203192

204193
events: list[LLMConvertibleEvent] = [obs1, obs2]
205194

@@ -216,10 +205,7 @@ def test_observation_event_preferred_over_agent_error(self) -> None:
216205
tool_call_id="call_1",
217206
error="Restart occurred while tool was running",
218207
)
219-
220-
obs = create_autospec(ObservationEvent, instance=True)
221-
obs.tool_call_id = "call_1" # Same tool_call_id!
222-
obs.id = "obs_1"
208+
obs = create_observation_event("obs_1", "call_1", "Actual result")
223209

224210
events: list[LLMConvertibleEvent] = [agent_error, obs]
225211

@@ -229,24 +215,15 @@ def test_observation_event_preferred_over_agent_error(self) -> None:
229215

230216
def test_agent_error_before_observation_event(self) -> None:
231217
"""Test AgentErrorEvent followed by ObservationEvent (restart scenario)."""
232-
action = create_autospec(ActionEvent, instance=True)
233-
action.tool_call_id = "call_1"
234-
action.id = "action_1"
235-
action.llm_response_id = "response_1"
236-
237218
agent_error = AgentErrorEvent(
238219
tool_name="terminal",
239220
tool_call_id="call_1",
240221
error="A restart occurred while this tool was in progress.",
241222
)
242-
243-
obs = create_autospec(ObservationEvent, instance=True)
244-
obs.tool_call_id = "call_1"
245-
obs.id = "obs_1"
223+
obs = create_observation_event("obs_1", "call_1", "Actual result")
246224

247225
events: list[LLMConvertibleEvent] = [
248226
message_event("User message"),
249-
action,
250227
agent_error,
251228
obs, # Actual result arrives later
252229
]
@@ -283,10 +260,7 @@ def test_user_reject_observation_handling(self) -> None:
283260
action_id="action_1",
284261
rejection_reason="User rejected",
285262
)
286-
287-
obs = create_autospec(ObservationEvent, instance=True)
288-
obs.tool_call_id = "call_1"
289-
obs.id = "obs_1"
263+
obs = create_observation_event("obs_1", "call_1", "Actual result")
290264

291265
events: list[LLMConvertibleEvent] = [reject, obs]
292266

@@ -302,14 +276,10 @@ def test_mixed_scenario_multiple_tool_calls(self) -> None:
302276
tool_call_id="call_1",
303277
error="Restart error",
304278
)
305-
obs1 = create_autospec(ObservationEvent, instance=True)
306-
obs1.tool_call_id = "call_1"
307-
obs1.id = "obs_1"
279+
obs1 = create_observation_event("obs_1", "call_1", "Result 1")
308280

309281
# Tool call 2: single observation (no duplicate)
310-
obs2 = create_autospec(ObservationEvent, instance=True)
311-
obs2.tool_call_id = "call_2"
312-
obs2.id = "obs_2"
282+
obs2 = create_observation_event("obs_2", "call_2", "Result 2")
313283

314284
# Tool call 3: single error (no duplicate)
315285
error3 = AgentErrorEvent(
@@ -340,9 +310,7 @@ def setup_method(self) -> None:
340310

341311
def test_complete_indices_returned(self) -> None:
342312
"""Test that manipulation indices are complete (no restrictions)."""
343-
obs = create_autospec(ObservationEvent, instance=True)
344-
obs.tool_call_id = "call_1"
345-
obs.id = "obs_1"
313+
obs = create_observation_event("obs_1", "call_1", "Result")
346314

347315
events: list[LLMConvertibleEvent] = [
348316
message_event("Start"),

0 commit comments

Comments
 (0)