Skip to content

Commit 13a95c4

Browse files
XinranTangcopybara-github
authored andcommitted
feat: Add get_events util function in invocation_context
PiperOrigin-RevId: 809111315
1 parent f157b2e commit 13a95c4

File tree

2 files changed

+148
-0
lines changed

2 files changed

+148
-0
lines changed

src/google/adk/agents/invocation_context.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525

2626
from ..artifacts.base_artifact_service import BaseArtifactService
2727
from ..auth.credential_service.base_credential_service import BaseCredentialService
28+
from ..events.event import Event
2829
from ..memory.base_memory_service import BaseMemoryService
2930
from ..plugins.plugin_manager import PluginManager
3031
from ..sessions.base_session_service import BaseSessionService
3132
from ..sessions.session import Session
33+
from ..utils.feature_decorator import working_in_progress
3234
from .active_streaming_tool import ActiveStreamingTool
3335
from .base_agent import BaseAgent
3436
from .live_request_queue import LiveRequestQueue
@@ -215,6 +217,33 @@ def app_name(self) -> str:
215217
def user_id(self) -> str:
216218
return self.session.user_id
217219

220+
@working_in_progress("incomplete feature, don't use yet")
221+
def get_events(
222+
self,
223+
current_invocation: bool = False,
224+
current_branch: bool = False,
225+
) -> list[Event]:
226+
"""Returns the events from the current session.
227+
228+
Args:
229+
current_invocation: Whether to filter the events by the current
230+
invocation.
231+
current_branch: Whether to filter the events by the current branch.
232+
233+
Returns:
234+
A list of events from the current session.
235+
"""
236+
results = self.session.events
237+
if current_invocation:
238+
results = [
239+
event
240+
for event in results
241+
if event.invocation_id == self.invocation_id
242+
]
243+
if current_branch:
244+
results = [event for event in results if event.branch == self.branch]
245+
return results
246+
218247

219248
def new_invocation_context_id() -> str:
220249
return "e-" + str(uuid.uuid4())
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest.mock import Mock
16+
17+
from google.adk.agents.base_agent import BaseAgent
18+
from google.adk.agents.invocation_context import InvocationContext
19+
from google.adk.events.event import Event
20+
from google.adk.sessions.base_session_service import BaseSessionService
21+
from google.adk.sessions.session import Session
22+
import pytest
23+
24+
25+
class TestInvocationContext:
26+
"""Test suite for InvocationContext."""
27+
28+
@pytest.fixture
29+
def mock_events(self):
30+
"""Create mock events for testing."""
31+
event1 = Mock(spec=Event)
32+
event1.invocation_id = 'inv_1'
33+
event1.branch = 'agent_1'
34+
35+
event2 = Mock(spec=Event)
36+
event2.invocation_id = 'inv_1'
37+
event2.branch = 'agent_2'
38+
39+
event3 = Mock(spec=Event)
40+
event3.invocation_id = 'inv_2'
41+
event3.branch = 'agent_1'
42+
43+
event4 = Mock(spec=Event)
44+
event4.invocation_id = 'inv_2'
45+
event4.branch = 'agent_2'
46+
47+
return [event1, event2, event3, event4]
48+
49+
@pytest.fixture
50+
def mock_invocation_context(self, mock_events):
51+
"""Create a mock invocation context for testing."""
52+
ctx = InvocationContext(
53+
session_service=Mock(spec=BaseSessionService),
54+
agent=Mock(spec=BaseAgent),
55+
invocation_id='inv_1',
56+
branch='agent_1',
57+
session=Mock(spec=Session, events=mock_events),
58+
)
59+
return ctx
60+
61+
def test_get_events_returns_all_events_by_default(
62+
self, mock_invocation_context, mock_events
63+
):
64+
"""Tests that get_events returns all events when no filters are applied."""
65+
events = mock_invocation_context.get_events()
66+
assert events == mock_events
67+
68+
def test_get_events_filters_by_current_invocation(
69+
self, mock_invocation_context, mock_events
70+
):
71+
"""Tests that get_events correctly filters by the current invocation."""
72+
event1, event2, _, _ = mock_events
73+
events = mock_invocation_context.get_events(current_invocation=True)
74+
assert events == [event1, event2]
75+
76+
def test_get_events_filters_by_current_branch(
77+
self, mock_invocation_context, mock_events
78+
):
79+
"""Tests that get_events correctly filters by the current branch."""
80+
event1, _, event3, _ = mock_events
81+
events = mock_invocation_context.get_events(current_branch=True)
82+
assert events == [event1, event3]
83+
84+
def test_get_events_filters_by_invocation_and_branch(
85+
self, mock_invocation_context, mock_events
86+
):
87+
"""Tests that get_events filters by invocation and branch."""
88+
event1, _, _, _ = mock_events
89+
events = mock_invocation_context.get_events(
90+
current_invocation=True,
91+
current_branch=True,
92+
)
93+
assert events == [event1]
94+
95+
def test_get_events_with_no_events_in_session(self, mock_invocation_context):
96+
"""Tests get_events when the session has no events."""
97+
mock_invocation_context.session.events = []
98+
events = mock_invocation_context.get_events()
99+
assert not events
100+
101+
def test_get_events_with_no_matching_events(self, mock_invocation_context):
102+
"""Tests get_events when no events match the filters."""
103+
mock_invocation_context.invocation_id = 'inv_3'
104+
mock_invocation_context.branch = 'branch_C'
105+
106+
# Filter by invocation
107+
events = mock_invocation_context.get_events(current_invocation=True)
108+
assert not events
109+
110+
# Filter by branch
111+
events = mock_invocation_context.get_events(current_branch=True)
112+
assert not events
113+
114+
# Filter by both
115+
events = mock_invocation_context.get_events(
116+
current_invocation=True,
117+
current_branch=True,
118+
)
119+
assert not events

0 commit comments

Comments
 (0)