20
20
from google .adk .events .event import Event
21
21
from google .adk .sessions .base_session_service import BaseSessionService
22
22
from google .adk .sessions .session import Session
23
+ from google .genai .types import Content
23
24
from google .genai .types import FunctionCall
24
25
from google .genai .types import Part
25
26
import pytest
@@ -67,31 +68,31 @@ def test_get_events_returns_all_events_by_default(
67
68
self , mock_invocation_context , mock_events
68
69
):
69
70
"""Tests that get_events returns all events when no filters are applied."""
70
- events = mock_invocation_context .get_events ()
71
+ events = mock_invocation_context ._get_events ()
71
72
assert events == mock_events
72
73
73
74
def test_get_events_filters_by_current_invocation (
74
75
self , mock_invocation_context , mock_events
75
76
):
76
77
"""Tests that get_events correctly filters by the current invocation."""
77
78
event1 , event2 , _ , _ = mock_events
78
- events = mock_invocation_context .get_events (current_invocation = True )
79
+ events = mock_invocation_context ._get_events (current_invocation = True )
79
80
assert events == [event1 , event2 ]
80
81
81
82
def test_get_events_filters_by_current_branch (
82
83
self , mock_invocation_context , mock_events
83
84
):
84
85
"""Tests that get_events correctly filters by the current branch."""
85
86
event1 , _ , event3 , _ = mock_events
86
- events = mock_invocation_context .get_events (current_branch = True )
87
+ events = mock_invocation_context ._get_events (current_branch = True )
87
88
assert events == [event1 , event3 ]
88
89
89
90
def test_get_events_filters_by_invocation_and_branch (
90
91
self , mock_invocation_context , mock_events
91
92
):
92
93
"""Tests that get_events filters by invocation and branch."""
93
94
event1 , _ , _ , _ = mock_events
94
- events = mock_invocation_context .get_events (
95
+ events = mock_invocation_context ._get_events (
95
96
current_invocation = True ,
96
97
current_branch = True ,
97
98
)
@@ -100,7 +101,7 @@ def test_get_events_filters_by_invocation_and_branch(
100
101
def test_get_events_with_no_events_in_session (self , mock_invocation_context ):
101
102
"""Tests get_events when the session has no events."""
102
103
mock_invocation_context .session .events = []
103
- events = mock_invocation_context .get_events ()
104
+ events = mock_invocation_context ._get_events ()
104
105
assert not events
105
106
106
107
def test_get_events_with_no_matching_events (self , mock_invocation_context ):
@@ -109,15 +110,15 @@ def test_get_events_with_no_matching_events(self, mock_invocation_context):
109
110
mock_invocation_context .branch = 'branch_C'
110
111
111
112
# Filter by invocation
112
- events = mock_invocation_context .get_events (current_invocation = True )
113
+ events = mock_invocation_context ._get_events (current_invocation = True )
113
114
assert not events
114
115
115
116
# Filter by branch
116
- events = mock_invocation_context .get_events (current_branch = True )
117
+ events = mock_invocation_context ._get_events (current_branch = True )
117
118
assert not events
118
119
119
120
# Filter by both
120
- events = mock_invocation_context .get_events (
121
+ events = mock_invocation_context ._get_events (
121
122
current_invocation = True ,
122
123
current_branch = True ,
123
124
)
@@ -225,3 +226,114 @@ def test_is_resumable_no_config(self):
225
226
"""Tests that is_resumable is False when no resumability config is set."""
226
227
invocation_context = self ._create_test_invocation_context (None )
227
228
assert not invocation_context .is_resumable
229
+
230
+
231
+ class TestFindMatchingFunctionCall :
232
+ """Test suite for find_matching_function_call."""
233
+
234
+ @pytest .fixture
235
+ def test_invocation_context (self ):
236
+ """Create a mock invocation context for testing."""
237
+
238
+ def _create_invocation_context (events ):
239
+ return InvocationContext (
240
+ session_service = Mock (spec = BaseSessionService ),
241
+ agent = Mock (spec = BaseAgent , name = 'agent' ),
242
+ invocation_id = 'inv_1' ,
243
+ session = Mock (spec = Session , events = events ),
244
+ )
245
+
246
+ return _create_invocation_context
247
+
248
+ def test_find_matching_function_call_found (self , test_invocation_context ):
249
+ """Tests that a matching function call is found."""
250
+ fc = Part .from_function_call (name = 'some_tool' , args = {})
251
+ fc .function_call .id = 'test_function_call_id'
252
+ fc_event = Event (
253
+ invocation_id = 'inv_1' ,
254
+ author = 'agent' ,
255
+ content = testing_utils .ModelContent ([fc ]),
256
+ )
257
+ fr = Part .from_function_response (
258
+ name = 'some_tool' , response = {'result' : 'ok' }
259
+ )
260
+ fr .function_response .id = 'test_function_call_id'
261
+ fr_event = Event (
262
+ invocation_id = 'inv_1' ,
263
+ author = 'agent' ,
264
+ content = Content (role = 'user' , parts = [fr ]),
265
+ )
266
+ invocation_context = test_invocation_context ([fc_event , fr_event ])
267
+ matching_fc_event = invocation_context ._find_matching_function_call (
268
+ fr_event
269
+ )
270
+ assert testing_utils .simplify_content (
271
+ matching_fc_event .content
272
+ ) == testing_utils .simplify_content (fc_event .content )
273
+
274
+ def test_find_matching_function_call_not_found (self , test_invocation_context ):
275
+ """Tests that no matching function call is returned if id doesn't match."""
276
+ fc = Part .from_function_call (name = 'some_tool' , args = {})
277
+ fc .function_call .id = 'another_function_call_id'
278
+ fc_event = Event (
279
+ invocation_id = 'inv_1' ,
280
+ author = 'agent' ,
281
+ content = testing_utils .ModelContent ([fc ]),
282
+ )
283
+ fr = Part .from_function_response (
284
+ name = 'some_tool' , response = {'result' : 'ok' }
285
+ )
286
+ fr .function_response .id = 'test_function_call_id'
287
+ fr_event = Event (
288
+ invocation_id = 'inv_1' ,
289
+ author = 'agent' ,
290
+ content = Content (role = 'user' , parts = [fr ]),
291
+ )
292
+ invocation_context = test_invocation_context ([fc_event , fr_event ])
293
+ match = invocation_context ._find_matching_function_call (fr_event )
294
+ assert match is None
295
+
296
+ def test_find_matching_function_call_no_call_events (
297
+ self , test_invocation_context
298
+ ):
299
+ """Tests that no matching function call is returned if there are no call events."""
300
+ fr = Part .from_function_response (
301
+ name = 'some_tool' , response = {'result' : 'ok' }
302
+ )
303
+ fr .function_response .id = 'test_function_call_id'
304
+ fr_event = Event (
305
+ invocation_id = 'inv_1' ,
306
+ author = 'agent' ,
307
+ content = Content (role = 'user' , parts = [fr ]),
308
+ )
309
+ invocation_context = test_invocation_context ([fr_event ])
310
+ match = invocation_context ._find_matching_function_call (fr_event )
311
+ assert match is None
312
+
313
+ def test_find_matching_function_call_no_response_in_event (
314
+ self , test_invocation_context
315
+ ):
316
+ """Tests result is None if function_response_event has no function response."""
317
+ fr_event_no_fr = Event (
318
+ author = 'agent' ,
319
+ content = Content (role = 'user' , parts = [Part (text = 'user message' )]),
320
+ )
321
+ fc = Part .from_function_call (name = 'some_tool' , args = {})
322
+ fc .function_call .id = 'test_function_call_id'
323
+ fc_event = Event (
324
+ invocation_id = 'inv_1' ,
325
+ author = 'agent' ,
326
+ content = testing_utils .ModelContent ([fc ]),
327
+ )
328
+ fr = Part .from_function_response (
329
+ name = 'some_tool' , response = {'result' : 'ok' }
330
+ )
331
+ fr .function_response .id = 'test_function_call_id'
332
+ fr_event = Event (
333
+ invocation_id = 'inv_1' ,
334
+ author = 'agent' ,
335
+ content = Content (role = 'user' , parts = [Part (text = 'user message' )]),
336
+ )
337
+ invocation_context = test_invocation_context ([fc_event , fr_event ])
338
+ match = invocation_context ._find_matching_function_call (fr_event_no_fr )
339
+ assert match is None
0 commit comments