Commit 67f3563

Enhance streaming and parsing functionality: add normalization for tool fragments, improve response text collection, and implement comprehensive unit tests for nested response handling.
1 parent 612fa50 commit 67f3563

5 files changed, +290 -33 lines

src/compendiumscribe/research/execution/stream_events.py

Lines changed: 117 additions & 19 deletions
@@ -25,6 +25,93 @@
 ]


+def _coerce_action_payload(payload: Any) -> dict[str, Any]:
+    """Normalise tool arguments into a mapping suitable for summaries."""
+
+    if isinstance(payload, dict):
+        return payload
+
+    if isinstance(payload, str):
+        try:
+            parsed = json.loads(payload)
+        except json.JSONDecodeError:
+            return {"input": payload}
+        if isinstance(parsed, dict):
+            return parsed
+        return {"value": parsed}
+
+    if payload is None:
+        return {}
+
+    if isinstance(payload, (list, tuple, set)):
+        return {"items": [item for item in payload]}
+
+    return {"value": payload}
+
+
+def normalize_tool_fragment(fragment: Any) -> dict[str, Any] | None:
+    """Coerce streaming tool fragments into the tool-call snapshot format."""
+
+    if not isinstance(fragment, dict):
+        return None
+
+    candidate = dict(fragment)
+
+    embedded = candidate.pop("tool_call", None)
+    if isinstance(embedded, dict):
+        candidate = {**embedded, **candidate}
+
+    # Promote common identifier fields for streaming events.
+    identifier = first_non_empty(
+        coerce_optional_string(candidate.get(key))
+        for key in ("id", "tool_call_id", "call_id")
+    )
+    if identifier:
+        candidate["id"] = identifier
+
+    raw_type = coerce_optional_string(candidate.get("type"))
+    tool_name = coerce_optional_string(candidate.get("tool_name"))
+    tool_type = coerce_optional_string(candidate.get("tool_type"))
+
+    normalized_type: str | None = None
+    if raw_type and raw_type.endswith("_call"):
+        normalized_type = raw_type
+    elif tool_name:
+        normalized_type = f"{tool_name.replace('.', '_')}_call"
+    elif tool_type:
+        normalized_type = f"{tool_type.replace('.', '_')}_call"
+    elif raw_type and raw_type.endswith("_delta"):
+        normalized_type = f"{raw_type[: -len('_delta')]}_call"
+
+    if normalized_type:
+        candidate["type"] = normalized_type
+
+    status = first_non_empty(
+        coerce_optional_string(candidate.get(key))
+        for key in ("status", "state")
+    )
+    if status:
+        candidate["status"] = status
+
+    if "arguments" in candidate and "action" not in candidate:
+        candidate["action"] = _coerce_action_payload(
+            candidate.pop("arguments")
+        )
+
+    if "call_arguments" in candidate and "action" not in candidate:
+        candidate["action"] = _coerce_action_payload(
+            candidate.pop("call_arguments")
+        )
+
+    if "input" in candidate and "action" not in candidate:
+        candidate["action"] = _coerce_action_payload(candidate.pop("input"))
+
+    if not coerce_optional_string(candidate.get("type")):
+        return None
+
+    return candidate
+
+
 def emit_trace_updates_from_item(
     item: Any,
     *,
@@ -58,8 +145,12 @@ def accumulate_stream_tool_event(
     """Merge incremental fragments for a single tool call while streaming."""
     fragments = collect_stream_fragments(item)
     event_id = first_non_empty(
-        coerce_optional_string(fragment.get("id"))
+        first_non_empty(
+            coerce_optional_string(fragment.get(key))
+            for key in ("id", "tool_call_id", "call_id")
+        )
         for fragment in fragments
+        if isinstance(fragment, dict)
     )

     tool_events: dict[str, dict[str, Any]] = stream_state.setdefault(
@@ -97,13 +188,16 @@ def extract_stream_tool_fragment(
     existing: dict[str, Any],
 ) -> dict[str, Any] | None:
     for fragment in fragments:
-        item_type = coerce_optional_string(fragment.get("type"))
-        if item_type and item_type.endswith("_call"):
-            return fragment
+        normalized_fragment = normalize_tool_fragment(fragment)
+        if normalized_fragment is not None:
+            return normalized_fragment

         if (
-            fragment.get("response") is not None
-            or fragment.get("result") is not None
+            isinstance(fragment, dict)
+            and (
+                fragment.get("response") is not None
+                or fragment.get("result") is not None
+            )
         ):
             snapshot = dict(existing)
             if fragment.get("response") is not None:
@@ -114,7 +208,8 @@
                 snapshot.setdefault("result", {}).update(
                     fragment.get("result") or {}
                 )
-            return snapshot
+            normalized_snapshot = normalize_tool_fragment(snapshot)
+            return normalized_snapshot or snapshot

     return None

@@ -123,31 +218,34 @@ def merge_tool_fragment(
     existing: dict[str, Any],
     fragment: dict[str, Any],
 ) -> dict[str, Any]:
-    merged = dict(existing)
+    normalized_existing = normalize_tool_fragment(existing) or existing
+    normalized_fragment = normalize_tool_fragment(fragment) or fragment
+
+    merged = dict(normalized_existing)

-    if fragment.get("type"):
-        merged["type"] = fragment["type"]
-    if fragment.get("id"):
-        merged["id"] = fragment["id"]
-    if fragment.get("status"):
-        merged["status"] = fragment["status"]
+    if normalized_fragment.get("type"):
+        merged["type"] = normalized_fragment["type"]
+    if normalized_fragment.get("id"):
+        merged["id"] = normalized_fragment["id"]
+    if normalized_fragment.get("status"):
+        merged["status"] = normalized_fragment["status"]

-    action_fragment = fragment.get("action")
+    action_fragment = normalized_fragment.get("action")
     if action_fragment is not None:
         merged_action = merged.get("action", {})
         merged_action = merge_action_payload(merged_action, action_fragment)
         merged["action"] = merged_action

-    if fragment.get("response") is not None:
+    if normalized_fragment.get("response") is not None:
         merged["response"] = merge_response_payload(
             merged.get("response"),
-            fragment["response"],
+            normalized_fragment["response"],
         )

-    if fragment.get("result") is not None:
+    if normalized_fragment.get("result") is not None:
         merged["result"] = merge_response_payload(
             merged.get("result"),
-            fragment["result"],
+            normalized_fragment["result"],
         )

     return merged
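
A minimal usage sketch of the new normalization helper, assuming the module is importable as compendiumscribe.research.execution.stream_events (src layout) and that first_non_empty / coerce_optional_string behave as the diff implies; the values in the comments follow the branch logic above, not an authoritative API:

from compendiumscribe.research.execution.stream_events import (
    normalize_tool_fragment,
)

# A raw streaming delta keyed by tool_name / tool_call_id / arguments.
fragment = {
    "type": "tool_call_delta",
    "tool_name": "web.search",
    "tool_call_id": "ws_2",
    "arguments": {"query_delta": "history of"},
}

snapshot = normalize_tool_fragment(fragment)
# Per the branches above, the snapshot should carry:
#   snapshot["type"]   == "web_search_call"   # tool_name -> "<name>_call"
#   snapshot["id"]     == "ws_2"              # promoted from tool_call_id
#   snapshot["action"] == {"query_delta": "history of"}  # from "arguments"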

src/compendiumscribe/research/execution/streaming.py

Lines changed: 16 additions & 0 deletions
@@ -135,6 +135,22 @@ def handle_stream_event(
             f"Deep research stream reported an error: {message}"
         )

+    if "tool_call" in normalized:
+        fragment = (
+            get_field(event, "item")
+            or get_field(event, "delta")
+            or get_field(event, "partial")
+            or get_field(event, "data")
+        )
+        if fragment is not None:
+            emit_trace_updates_from_item(
+                fragment,
+                config=config,
+                seen_tokens=seen_trace_tokens,
+                stream_state=stream_state,
+            )
+        return None
+
     if normalized == "response.output_item.added":
         item = get_field(event, "item")
         if item is not None:
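
For illustration, the event shapes this new branch is meant to catch, mirroring the unit tests added below: any event whose normalized type mentions "tool_call" is routed to emit_trace_updates_from_item, with the fragment looked up on item, then delta, then partial, then data.

from types import SimpleNamespace

# The delta event carries its fragment under .delta; the completed event
# carries a full snapshot under .item.
delta_event = SimpleNamespace(
    type="response.tool_call.delta",
    delta={"tool_name": "web.search", "tool_call_id": "ws_2",
           "arguments": {"query_delta": "history of"}},
)
completed_event = SimpleNamespace(
    type="response.tool_call.completed",
    item={"tool_name": "web.search", "tool_call_id": "ws_2",
          "status": "completed", "arguments": {"query": "history of flutes"}},
)
# handle_stream_event would pick delta_event.delta and completed_event.item
# respectively before forwarding them to emit_trace_updates_from_item.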

src/compendiumscribe/research/parsing.py

Lines changed: 43 additions & 13 deletions
@@ -8,6 +8,40 @@
 from .utils import coerce_optional_string, get_field


+def _iter_text_fragments(value: Any) -> list[str]:
+    """Recursively extract textual fragments from nested response payloads."""
+
+    fragments: list[str] = []
+
+    def visit(candidate: Any) -> None:
+        if candidate is None:
+            return
+
+        if isinstance(candidate, str):
+            if candidate:
+                fragments.append(candidate)
+            return
+
+        if isinstance(candidate, (list, tuple, set)):
+            for item in candidate:
+                visit(item)
+            return
+
+        if isinstance(candidate, dict):
+            # Many response payloads nest text inside these keys.
+            for key in ("text", "value", "content"):
+                if key in candidate:
+                    visit(candidate[key])
+            return
+
+        # Safety fallback: stringify scalars only (avoid object reprs).
+        if isinstance(candidate, (int, float, bool)):
+            fragments.append(str(candidate))
+
+    visit(value)
+    return fragments
+
+
 def parse_deep_research_response(response: Any) -> dict[str, Any]:
     text_payload = collect_response_text(response)
     return decode_json_payload(text_payload)
@@ -16,7 +50,9 @@ def parse_deep_research_response(response: Any) -> dict[str, Any]:
 def collect_response_text(response: Any) -> str:
     output_text = get_field(response, "output_text")
     if output_text:
-        return str(output_text).strip()
+        fragments = _iter_text_fragments(output_text)
+        if fragments:
+            return "".join(fragments).strip()

     output_items = get_field(response, "output")
     text_parts: list[str] = []
@@ -26,19 +62,13 @@ def collect_response_text(response: Any) -> str:
         item_type = coerce_optional_string(get_field(item, "type"))
         if item_type == "message":
             for content in get_field(item, "content") or []:
-                text_value = coerce_optional_string(
-                    get_field(content, "text")
-                )
-                if not text_value:
-                    text_value = coerce_optional_string(
-                        get_field(content, "value")
-                    )
-                if text_value:
-                    text_parts.append(text_value)
+                fragments = _iter_text_fragments(content)
+                if fragments:
+                    text_parts.append("".join(fragments))
         elif item_type == "output_text":
-            text = coerce_optional_string(get_field(item, "text"))
-            if text:
-                text_parts.append(text)
+            fragments = _iter_text_fragments(item)
+            if fragments:
+                text_parts.append("".join(fragments))

     if text_parts:
         return "".join(text_parts).strip()

tests/research/execution/test_core.py

Lines changed: 69 additions & 1 deletion
@@ -51,7 +51,19 @@ def callback(update):
     events = [
         SimpleNamespace(type="response.created"),
         SimpleNamespace(
-            type="response.output_item.added",
+            type="response.tool_call.delta",
+            delta={
+                "type": "web_search_call",
+                "id": "ws_1",
+                "status": "in_progress",
+                "action": {
+                    "type": "search",
+                    "query_delta": "double ",
+                },
+            },
+        ),
+        SimpleNamespace(
+            type="response.tool_call.completed",
             item={
                 "type": "web_search_call",
                 "id": "ws_1",
@@ -91,6 +103,62 @@ def create(self, **kwargs):
     )


+def test_execute_deep_research_streaming_handles_tool_name_payload():
+    progress_updates = []
+
+    def callback(update):
+        if update.phase == "trace":
+            progress_updates.append(update.message)
+
+    final_response = SimpleNamespace(
+        id="resp_456",
+        status="completed",
+        output_text='{"result": "ok"}',
+        output=[],
+    )
+
+    events = [
+        SimpleNamespace(type="response.created"),
+        SimpleNamespace(
+            type="response.tool_call.delta",
+            delta={
+                "type": "tool_call_delta",
+                "tool_name": "web.search",
+                "tool_call_id": "ws_2",
+                "arguments": {"query_delta": "history of"},
+            },
+        ),
+        SimpleNamespace(
+            type="response.tool_call.completed",
+            item={
+                "tool_name": "web.search",
+                "tool_call_id": "ws_2",
+                "status": "completed",
+                "arguments": {"query": "history of flutes"},
+            },
+        ),
+        SimpleNamespace(type="response.completed", response=final_response),
+    ]
+
+    class StubResponses:
+        def __init__(self):
+            self.calls: list[dict[str, object]] = []
+
+        def create(self, **kwargs):
+            self.calls.append(kwargs)
+            return StubStream(events, final_response)
+
+    responses = StubResponses()
+    client = SimpleNamespace(responses=responses)
+
+    config = ResearchConfig(stream_progress=True, progress_callback=callback)
+
+    execute_deep_research(client, "prompt", config)
+
+    assert progress_updates
+    assert any("history of flutes" in message for message in progress_updates)
+
+
 def test_execute_deep_research_streaming_raises_on_error():
     events = [
         SimpleNamespace(type="response.created"),
