|
10 | 10 | from ag_ui.core import RunAgentInput, UserMessage
|
11 | 11 | from adk_middleware import ADKAgent
|
12 | 12 | from google.adk.agents import Agent
|
| 13 | +from google.genai import types |
13 | 14 |
|
14 | 15 |
|
15 | 16 | async def test_message_events():
|
@@ -87,6 +88,142 @@ async def test_message_events():
|
87 | 88 | return validate_message_event_pattern(start_count, end_count, content_count, text_message_events)
|
88 | 89 |
|
89 | 90 |
|
| 91 | +async def test_message_events_from_before_agent_callback(): |
| 92 | + """Test that we get proper message events with correct START/CONTENT/END patterns, |
| 93 | + even if we return the message from before_agent_callback. |
| 94 | + """ |
| 95 | + |
| 96 | + if not os.getenv("GOOGLE_API_KEY"): |
| 97 | + print("⚠️ GOOGLE_API_KEY not set - using mock test") |
| 98 | + return await test_with_mock() |
| 99 | + |
| 100 | + print("🧪 Testing with real Google ADK agent...") |
| 101 | + |
| 102 | + event_message = "This message was not generated." |
| 103 | + def return_predefined_message(callback_context): |
| 104 | + return types.Content( |
| 105 | + parts=[types.Part(text=event_message)], |
| 106 | + role="model" # Assign model role to the overriding response |
| 107 | + ) |
| 108 | + |
| 109 | + # Create real agent |
| 110 | + agent = Agent( |
| 111 | + name="test_agent", |
| 112 | + instruction="You are a helpful assistant. Keep responses brief.", |
| 113 | + before_agent_callback=return_predefined_message |
| 114 | + ) |
| 115 | + |
| 116 | + # Create middleware with direct agent embedding |
| 117 | + adk_agent = ADKAgent( |
| 118 | + adk_agent=agent, |
| 119 | + app_name="test_app", |
| 120 | + user_id="test_user", |
| 121 | + use_in_memory_services=True, |
| 122 | + ) |
| 123 | + |
| 124 | + # Test input |
| 125 | + test_input = RunAgentInput( |
| 126 | + thread_id="test_thread", |
| 127 | + run_id="test_run", |
| 128 | + messages=[ |
| 129 | + UserMessage( |
| 130 | + id="msg_1", |
| 131 | + role="user", |
| 132 | + content="Say hello in exactly 3 words." |
| 133 | + ) |
| 134 | + ], |
| 135 | + state={}, |
| 136 | + context=[], |
| 137 | + tools=[], |
| 138 | + forwarded_props={} |
| 139 | + ) |
| 140 | + |
| 141 | + print("🚀 Running test request...") |
| 142 | + |
| 143 | + events = [] |
| 144 | + text_message_events = [] |
| 145 | + |
| 146 | + try: |
| 147 | + async for event in adk_agent.run(test_input): |
| 148 | + events.append(event) |
| 149 | + event_type = str(event.type) |
| 150 | + print(f"📧 {event_type}") |
| 151 | + |
| 152 | + # Track text message events specifically |
| 153 | + if "TEXT_MESSAGE" in event_type: |
| 154 | + text_message_events.append(event_type) |
| 155 | + |
| 156 | + except Exception as e: |
| 157 | + print(f"❌ Error during test: {e}") |
| 158 | + return False |
| 159 | + |
| 160 | + print(f"\n📊 Results:") |
| 161 | + print(f" Total events: {len(events)}") |
| 162 | + print(f" Text message events: {text_message_events}") |
| 163 | + |
| 164 | + # Analyze message event patterns |
| 165 | + start_count = text_message_events.count("EventType.TEXT_MESSAGE_START") |
| 166 | + end_count = text_message_events.count("EventType.TEXT_MESSAGE_END") |
| 167 | + content_count = text_message_events.count("EventType.TEXT_MESSAGE_CONTENT") |
| 168 | + |
| 169 | + print(f" START events: {start_count}") |
| 170 | + print(f" END events: {end_count}") |
| 171 | + print(f" CONTENT events: {content_count}") |
| 172 | + |
| 173 | + pattern_is_valid = validate_message_event_pattern(start_count, end_count, content_count, text_message_events) |
| 174 | + if not pattern_is_valid: |
| 175 | + return False |
| 176 | + |
| 177 | + expected_text_events = [ |
| 178 | + { |
| 179 | + "type": "EventType.TEXT_MESSAGE_START", |
| 180 | + }, |
| 181 | + { |
| 182 | + "type": "EventType.TEXT_MESSAGE_CONTENT", |
| 183 | + "delta": event_message |
| 184 | + }, |
| 185 | + { |
| 186 | + "type": "EventType.TEXT_MESSAGE_END", |
| 187 | + } |
| 188 | + ] |
| 189 | + return validate_message_events(events, expected_text_events) |
| 190 | + |
| 191 | + |
| 192 | +def validate_message_events(events, expected_events): |
| 193 | + """Compare expected events by type and delta (if delta exists).""" |
| 194 | + # Filter events to only those specified in expected_events |
| 195 | + event_types_to_check = {expected["type"] for expected in expected_events} |
| 196 | + |
| 197 | + filtered_events = [] |
| 198 | + for event in events: |
| 199 | + event_type_str = f"EventType.{event.type.value}" |
| 200 | + if event_type_str in event_types_to_check: |
| 201 | + filtered_events.append(event) |
| 202 | + |
| 203 | + if len(filtered_events) != len(expected_events): |
| 204 | + print(f"❌ Event count mismatch: expected {len(expected_events)}, got {len(filtered_events)}") |
| 205 | + return False |
| 206 | + |
| 207 | + for i, (event, expected) in enumerate(zip(filtered_events, expected_events)): |
| 208 | + # Check event type |
| 209 | + event_type_str = f"EventType.{event.type.value}" |
| 210 | + if event_type_str != expected["type"]: |
| 211 | + print(f"❌ Event {i}: type mismatch - expected {expected['type']}, got {event_type_str}") |
| 212 | + return False |
| 213 | + |
| 214 | + # Check delta if specified |
| 215 | + if "delta" in expected: |
| 216 | + if not hasattr(event, 'delta'): |
| 217 | + print(f"❌ Event {i}: expected delta field but event has none") |
| 218 | + return False |
| 219 | + if event.delta != expected["delta"]: |
| 220 | + print(f"❌ Event {i}: delta mismatch - expected '{expected['delta']}', got '{event.delta}'") |
| 221 | + return False |
| 222 | + |
| 223 | + print("✅ All expected events validated successfully") |
| 224 | + return True |
| 225 | + |
| 226 | + |
90 | 227 | def validate_message_event_pattern(start_count, end_count, content_count, text_message_events):
|
91 | 228 | """Validate that message events follow proper patterns."""
|
92 | 229 |
|
@@ -319,6 +456,13 @@ async def test_text_message_events():
|
319 | 456 | assert result, "Text message events test failed"
|
320 | 457 |
|
321 | 458 |
|
| 459 | +@pytest.mark.asyncio |
| 460 | +async def test_text_message_events_from_before_agent_callback(): |
| 461 | + """Test that we get proper message events with correct START/CONTENT/END patterns.""" |
| 462 | + result = await test_message_events_from_before_agent_callback() |
| 463 | + assert result, "Text message events for before_agent_callback test failed" |
| 464 | + |
| 465 | + |
322 | 466 | @pytest.mark.asyncio
|
323 | 467 | async def test_message_event_edge_cases():
|
324 | 468 | """Test edge cases for message event patterns."""
|
|
0 commit comments