Skip to content

Commit 07f4c8a

Browse files
moonbox3CopilotCopilot
authored
Python: Expose forwardedProps to agents and tools via session metadata (#5264)
* Expose forwarded_props to agents and tools via session metadata (#5239) Include forwarded_props from AG-UI request input_data in session.metadata (agent runner) and function_invocation_kwargs (workflow runner) so that agents, tools, and workflow executors can access request-level metadata such as invocation source flags from CopilotKit. - Add forwarded_props to base_metadata in _agent_run.py when present - Add 'forwarded_props' to AG_UI_INTERNAL_METADATA_KEYS to filter it from LLM-bound client metadata - Extract forwarded_props in _workflow_run.py and pass via function_invocation_kwargs to workflow.run() - Accept both snake_case and camelCase keys (forwarded_props/forwardedProps) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix(ag-ui): pass stream=True as literal to satisfy pyright overload resolution (#5239) The previous fix passed stream=True via **kwargs dict, which prevented pyright from resolving the Workflow.run() overload to the streaming variant. Pass stream=True as an explicit keyword argument so pyright can correctly infer the ResponseStream return type. Also remove unused pytest import in test file. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: address PR review feedback for forwarded_props (#5239) - Use key-presence checks instead of truthiness for forwarded_props so empty dict {} is forwarded correctly - Gate function_invocation_kwargs on workflow.run() signature inspection to avoid TypeError for workflows without **kwargs - Change _build_safe_metadata to drop (with warning) keys whose serialized values exceed 512 chars instead of truncating into invalid JSON - Rewrite metadata tests to exercise _build_safe_metadata directly with JSON-decodability and truncation assertions - Add workflow tests for empty dict forwarded_props, stream=True assertion, and signature-gated kwarg dropping Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * test: add stream=True assertions to CapturingWorkflow tests (#5239) Guard against accidental removal of the explicit stream=True kwarg in all forwarded_props CapturingWorkflow test cases. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review feedback for #5239: Python: Expose forwardedProps to agents and tools via session metadata --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 04aaf0c commit 07f4c8a

5 files changed

Lines changed: 339 additions & 13 deletions

File tree

python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,27 +69,36 @@
6969
logger = logging.getLogger(__name__)
7070

7171
# Keys that are internal to AG-UI orchestration and should not be passed to chat clients
72-
AG_UI_INTERNAL_METADATA_KEYS = {"ag_ui_thread_id", "ag_ui_run_id", "current_state"}
72+
AG_UI_INTERNAL_METADATA_KEYS = {"ag_ui_thread_id", "ag_ui_run_id", "current_state", "forwarded_props"}
7373

7474

7575
def _build_safe_metadata(thread_metadata: dict[str, Any] | None) -> dict[str, Any]:
76-
"""Build metadata dict with truncated string values for Azure compatibility.
76+
"""Build metadata dict with string values for Azure compatibility.
7777
78-
Azure has a 512 character limit per metadata value.
78+
Azure has a 512 character limit per metadata value. String values that
79+
already fit are kept as-is. Non-string values are JSON-serialized. If the
80+
resulting string exceeds 512 characters the key is **dropped** (with a
81+
warning) instead of truncated, because truncation can produce invalid JSON
82+
that downstream consumers cannot decode.
7983
8084
Args:
8185
thread_metadata: Raw metadata dict
8286
8387
Returns:
84-
Metadata with string values truncated to 512 chars
88+
Metadata with safe string values (each <= 512 chars)
8589
"""
8690
if not thread_metadata:
8791
return {}
8892
safe_metadata: dict[str, Any] = {}
8993
for key, value in thread_metadata.items():
9094
value_str = value if isinstance(value, str) else json.dumps(value)
9195
if len(value_str) > 512:
92-
value_str = value_str[:512]
96+
logger.warning(
97+
"Dropping metadata key %r: serialized value is %d chars (limit 512)",
98+
key,
99+
len(value_str),
100+
)
101+
continue
93102
safe_metadata[key] = value_str
94103
return safe_metadata
95104

@@ -790,6 +799,10 @@ async def run_agent_stream(
790799
"ag_ui_thread_id": thread_id,
791800
"ag_ui_run_id": run_id,
792801
}
802+
if "forwarded_props" in input_data:
803+
base_metadata["forwarded_props"] = input_data["forwarded_props"]
804+
elif "forwardedProps" in input_data:
805+
base_metadata["forwarded_props"] = input_data["forwardedProps"]
793806
if flow.current_state:
794807
base_metadata["current_state"] = flow.current_state
795808
session.metadata = _build_safe_metadata(base_metadata) # type: ignore[attr-defined]

python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from __future__ import annotations
66

7+
import inspect
78
import json
89
import logging
910
import uuid
@@ -581,11 +582,33 @@ def _drain_open_message() -> list[TextMessageEndEvent]:
581582
flow.accumulated_text = ""
582583
return [TextMessageEndEvent(message_id=current_message_id)]
583584

585+
fwd_kwargs: dict[str, Any] = {}
586+
if "forwarded_props" in input_data:
587+
forwarded_props = input_data["forwarded_props"]
588+
fwd_kwargs["function_invocation_kwargs"] = {"forwarded_props": forwarded_props}
589+
elif "forwardedProps" in input_data:
590+
forwarded_props = input_data["forwardedProps"]
591+
fwd_kwargs["function_invocation_kwargs"] = {"forwarded_props": forwarded_props}
592+
593+
# Only pass function_invocation_kwargs if the workflow.run signature accepts it
594+
if fwd_kwargs:
595+
try:
596+
sig = inspect.signature(workflow.run)
597+
params = sig.parameters
598+
accepts_fwd = "function_invocation_kwargs" in params or any(
599+
p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
600+
)
601+
except (ValueError, TypeError):
602+
accepts_fwd = False
603+
if not accepts_fwd:
604+
logger.debug("workflow.run() does not accept function_invocation_kwargs; dropping forwarded_props")
605+
fwd_kwargs = {}
606+
584607
try:
585608
if responses:
586-
event_stream = workflow.run(responses=responses, stream=True)
609+
event_stream = workflow.run(responses=responses, stream=True, **fwd_kwargs)
587610
else:
588-
event_stream = workflow.run(message=messages, stream=True)
611+
event_stream = workflow.run(message=messages, stream=True, **fwd_kwargs)
589612

590613
async for event in event_stream:
591614
event_type = getattr(event, "type", None)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
"""Tests for forwarded_props inclusion in AG-UI session metadata."""
4+
5+
import json
6+
from typing import Any
7+
8+
from agent_framework_ag_ui._agent_run import AG_UI_INTERNAL_METADATA_KEYS, _build_safe_metadata
9+
10+
11+
class TestForwardedPropsInSessionMetadata:
12+
"""Verify that forwarded_props is surfaced in session metadata and filtered from LLM metadata."""
13+
14+
def test_forwarded_props_in_internal_metadata_keys(self):
15+
"""forwarded_props is listed in AG_UI_INTERNAL_METADATA_KEYS to prevent LLM leakage."""
16+
assert "forwarded_props" in AG_UI_INTERNAL_METADATA_KEYS
17+
18+
def test_forwarded_props_filtered_from_client_metadata(self):
19+
"""forwarded_props is filtered out when building LLM-bound client metadata."""
20+
session_metadata: dict[str, Any] = {
21+
"ag_ui_thread_id": "t1",
22+
"ag_ui_run_id": "r1",
23+
"forwarded_props": '{"custom_flag": true}',
24+
}
25+
26+
client_metadata = {k: v for k, v in session_metadata.items() if k not in AG_UI_INTERNAL_METADATA_KEYS}
27+
28+
assert "forwarded_props" not in client_metadata
29+
assert "ag_ui_thread_id" not in client_metadata
30+
31+
32+
class TestBuildSafeMetadata:
33+
"""Verify _build_safe_metadata handles various value types correctly."""
34+
35+
def test_string_value_unchanged(self):
36+
result = _build_safe_metadata({"key": "hello"})
37+
assert result == {"key": "hello"}
38+
39+
def test_dict_value_serialized_to_json(self):
40+
result = _build_safe_metadata({"fp": {"flag": True, "source": "frontend"}})
41+
assert "fp" in result
42+
assert isinstance(result["fp"], str)
43+
# Must be valid, decodable JSON
44+
decoded = json.loads(result["fp"])
45+
assert decoded == {"flag": True, "source": "frontend"}
46+
47+
def test_empty_dict_serialized_to_json(self):
48+
result = _build_safe_metadata({"fp": {}})
49+
assert result["fp"] == "{}"
50+
assert json.loads(result["fp"]) == {}
51+
52+
def test_value_within_limit_kept(self):
53+
value = "x" * 512
54+
result = _build_safe_metadata({"key": value})
55+
assert result["key"] == value
56+
57+
def test_value_exceeding_limit_dropped(self):
58+
"""Values exceeding 512 chars are dropped entirely (not truncated)."""
59+
value = "x" * 513
60+
result = _build_safe_metadata({"key": value})
61+
assert "key" not in result
62+
63+
def test_json_value_exceeding_limit_dropped(self):
64+
"""JSON-serialized dict exceeding 512 chars is dropped, not truncated into invalid JSON."""
65+
big_dict = {f"key_{i}": "v" * 100 for i in range(50)}
66+
result = _build_safe_metadata({"forwarded_props": big_dict})
67+
assert "forwarded_props" not in result
68+
69+
def test_other_keys_preserved_when_one_dropped(self):
70+
"""Dropping one oversized key does not affect other keys."""
71+
result = _build_safe_metadata(
72+
{
73+
"small": "ok",
74+
"big": "x" * 600,
75+
}
76+
)
77+
assert result == {"small": "ok"}
78+
79+
def test_none_input_returns_empty(self):
80+
assert _build_safe_metadata(None) == {}
81+
82+
def test_empty_input_returns_empty(self):
83+
assert _build_safe_metadata({}) == {}

python/packages/ag-ui/tests/ag_ui/test_run.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@ def test_short_string_values(self):
6363
result = _build_safe_metadata(metadata)
6464
assert result == metadata
6565

66-
def test_truncates_long_strings(self):
67-
"""Truncates strings over 512 chars."""
66+
def test_drops_long_strings(self):
67+
"""Drops strings over 512 chars instead of truncating."""
6868
long_value = "x" * 1000
6969
metadata = {"key": long_value}
7070
result = _build_safe_metadata(metadata)
71-
assert len(result["key"]) == 512
71+
assert "key" not in result
7272

7373
def test_serializes_non_strings(self):
7474
"""Serializes non-string values to JSON."""
@@ -77,12 +77,12 @@ def test_serializes_non_strings(self):
7777
assert result["count"] == "42"
7878
assert result["items"] == "[1, 2, 3]"
7979

80-
def test_truncates_serialized_values(self):
81-
"""Truncates serialized values over 512 chars."""
80+
def test_drops_oversized_serialized_values(self):
81+
"""Drops serialized values over 512 chars instead of truncating."""
8282
long_list = list(range(200))
8383
metadata = {"data": long_list}
8484
result = _build_safe_metadata(metadata)
85-
assert len(result["data"]) == 512
85+
assert "data" not in result
8686

8787

8888
class TestHasOnlyToolCalls:

0 commit comments

Comments
 (0)