|
16 | 16 |
|
17 | 17 | from google.adk.agents.base_agent import BaseAgent |
18 | 18 | from google.adk.agents.invocation_context import InvocationContext |
| 19 | +from google.adk.apps import ResumabilityConfig |
19 | 20 | from google.adk.events.event import Event |
20 | 21 | from google.adk.sessions.base_session_service import BaseSessionService |
21 | 22 | from google.adk.sessions.session import Session |
| 23 | +from google.genai.types import FunctionCall |
| 24 | +from google.genai.types import Part |
22 | 25 | import pytest |
23 | 26 |
|
| 27 | +from .. import testing_utils |
| 28 | + |
24 | 29 |
|
25 | 30 | class TestInvocationContext: |
26 | 31 | """Test suite for InvocationContext.""" |
@@ -117,3 +122,87 @@ def test_get_events_with_no_matching_events(self, mock_invocation_context): |
117 | 122 | current_branch=True, |
118 | 123 | ) |
119 | 124 | assert not events |
| 125 | + |
| 126 | + |
| 127 | +class TestInvocationContextWithAppResumablity: |
| 128 | + """Test suite for InvocationContext regarding app resumability.""" |
| 129 | + |
| 130 | + @pytest.fixture |
| 131 | + def long_running_function_call(self) -> FunctionCall: |
| 132 | + """A long running function call.""" |
| 133 | + return FunctionCall( |
| 134 | + id='tool_call_id_1', |
| 135 | + name='long_running_function_call', |
| 136 | + args={}, |
| 137 | + ) |
| 138 | + |
| 139 | + @pytest.fixture |
| 140 | + def event_to_pause(self, long_running_function_call) -> Event: |
| 141 | + """An event with a long running function call.""" |
| 142 | + return Event( |
| 143 | + invocation_id='inv_1', |
| 144 | + author='agent', |
| 145 | + content=testing_utils.ModelContent( |
| 146 | + [Part(function_call=long_running_function_call)] |
| 147 | + ), |
| 148 | + long_running_tool_ids=[long_running_function_call.id], |
| 149 | + ) |
| 150 | + |
| 151 | + def _create_test_invocation_context( |
| 152 | + self, resumability_config |
| 153 | + ) -> InvocationContext: |
| 154 | + """Create a mock invocation context for testing.""" |
| 155 | + ctx = InvocationContext( |
| 156 | + session_service=Mock(spec=BaseSessionService), |
| 157 | + agent=Mock(spec=BaseAgent), |
| 158 | + invocation_id='inv_1', |
| 159 | + session=Mock(spec=Session), |
| 160 | + resumability_config=resumability_config, |
| 161 | + ) |
| 162 | + return ctx |
| 163 | + |
| 164 | + def test_should_pause_invocation_with_resumable_app(self, event_to_pause): |
| 165 | + """Tests should_pause_invocation with a resumable app.""" |
| 166 | + mock_invocation_context = self._create_test_invocation_context( |
| 167 | + ResumabilityConfig(is_resumable=True) |
| 168 | + ) |
| 169 | + |
| 170 | + assert mock_invocation_context.should_pause_invocation(event_to_pause) |
| 171 | + |
| 172 | + def test_should_not_pause_invocation_with_non_resumable_app( |
| 173 | + self, event_to_pause |
| 174 | + ): |
| 175 | + """Tests should_pause_invocation with a non-resumable app.""" |
| 176 | + invocation_context = self._create_test_invocation_context( |
| 177 | + ResumabilityConfig(is_resumable=False) |
| 178 | + ) |
| 179 | + |
| 180 | + assert not invocation_context.should_pause_invocation(event_to_pause) |
| 181 | + |
| 182 | + def test_should_not_pause_invocation_with_no_long_running_tool_ids( |
| 183 | + self, event_to_pause |
| 184 | + ): |
| 185 | + """Tests should_pause_invocation with no long running tools.""" |
| 186 | + invocation_context = self._create_test_invocation_context( |
| 187 | + ResumabilityConfig(is_resumable=True) |
| 188 | + ) |
| 189 | + nonpausable_event = event_to_pause.model_copy( |
| 190 | + update={'long_running_tool_ids': []} |
| 191 | + ) |
| 192 | + |
| 193 | + assert not invocation_context.should_pause_invocation(nonpausable_event) |
| 194 | + |
| 195 | + def test_should_not_pause_invocation_with_no_function_calls( |
| 196 | + self, event_to_pause |
| 197 | + ): |
| 198 | + """Tests should_pause_invocation with a non-model event.""" |
| 199 | + mock_invocation_context = self._create_test_invocation_context( |
| 200 | + ResumabilityConfig(is_resumable=True) |
| 201 | + ) |
| 202 | + nonpausable_event = event_to_pause.model_copy( |
| 203 | + update={'content': testing_utils.UserContent('test text part')} |
| 204 | + ) |
| 205 | + |
| 206 | + assert not mock_invocation_context.should_pause_invocation( |
| 207 | + nonpausable_event |
| 208 | + ) |
0 commit comments