Skip to content

Commit 3b74cd0

Browse files
authored
fix(signals): Double-check directly provided session ids to ensure they exsist. (#42230)
1 parent 0d4fffc commit 3b74cd0

File tree

3 files changed

+80
-3
lines changed

3 files changed

+80
-3
lines changed

ee/hogai/chat_agent/session_summaries/nodes.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,20 @@ def _get_session_ids_with_filters(self, replay_filters: RecordingsQuery) -> list
355355
session_ids = [recording["session_id"] for recording in results.results]
356356
return session_ids if session_ids else None
357357

358+
def _validate_specific_session_ids(self, session_ids: list[str]) -> list[str] | None:
359+
"""Validate that specific session IDs exist in the database."""
360+
from posthog.session_recordings.queries.session_replay_events import SessionReplayEvents
361+
362+
replay_events = SessionReplayEvents()
363+
sessions_found, _, _ = replay_events.sessions_found_with_timestamps(
364+
session_ids=session_ids,
365+
team=self._node._team,
366+
)
367+
if not sessions_found:
368+
return None
369+
# Preserve the original order, filtering out invalid sessions
370+
return [sid for sid in session_ids if sid in sessions_found]
371+
358372
async def _generate_filter_query(self, plain_text_query: str, config: RunnableConfig) -> str:
359373
"""Generate a filter query for the user's summarization query to keep the search context clear"""
360374
messages = [
@@ -408,8 +422,13 @@ async def search_sessions(
408422
return self._node._create_error_response(self._node._base_error_instructions, state)
409423
# Use specific session IDs, if provided
410424
if state.specific_session_ids_to_summarize:
411-
# Return session ids right away to use in the next step
412-
return state.specific_session_ids_to_summarize
425+
# Validate that sessions exist before using them
426+
valid_session_ids = await database_sync_to_async(
427+
self._validate_specific_session_ids, thread_sensitive=False
428+
)(state.specific_session_ids_to_summarize)
429+
if not valid_session_ids:
430+
return None
431+
return valid_session_ids
413432
# Use current filters, if provided
414433
if state.should_use_current_filters:
415434
if not current_filters:

ee/hogai/chat_agent/session_summaries/test/__snapshots__/test_nodes.ambr

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,33 @@
192192
use_hive_partitioning=0
193193
'''
194194
# ---
195+
# name: TestSessionSummarizationNodeFilterGeneration.test_validate_specific_session_ids_all_invalid_returns_none
196+
'''
197+
198+
SELECT session_id,
199+
min(min_first_timestamp) as min_timestamp,
200+
max(max_last_timestamp) as max_timestamp,
201+
max(retention_period_days) as retention_period_days,
202+
dateTrunc('DAY', min_timestamp) + toIntervalDay(coalesce(retention_period_days, 21)) as expiry_time
203+
FROM session_replay_events PREWHERE team_id = 99999
204+
AND session_id IN ['nonexistent-session-1', 'nonexistent-session-2']
205+
AND min_first_timestamp <= '2025-09-03 12:00:00'
206+
GROUP BY session_id
207+
HAVING expiry_time >= '2025-09-03 12:00:00'
208+
'''
209+
# ---
210+
# name: TestSessionSummarizationNodeFilterGeneration.test_validate_specific_session_ids_filters_invalid
211+
'''
212+
213+
SELECT session_id,
214+
min(min_first_timestamp) as min_timestamp,
215+
max(max_last_timestamp) as max_timestamp,
216+
max(retention_period_days) as retention_period_days,
217+
dateTrunc('DAY', min_timestamp) + toIntervalDay(coalesce(retention_period_days, 21)) as expiry_time
218+
FROM session_replay_events PREWHERE team_id = 99999
219+
AND session_id IN ['01990f72-b600-7fa3-9a77-341582154177', '01990f72-b600-76df-a71a-f8d777d51361', 'nonexistent-session-1', 'nonexistent-session-2']
220+
AND min_first_timestamp <= '2025-09-03 12:00:00'
221+
GROUP BY session_id
222+
HAVING expiry_time >= '2025-09-03 12:00:00'
223+
'''
224+
# ---

ee/hogai/chat_agent/session_summaries/test/test_nodes.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -595,12 +595,17 @@ def test_arun_both_current_session_and_filters_true_returns_error(self) -> None:
595595
self.assertIn("encountered an issue", message.content)
596596

597597
@patch("ee.hogai.chat_agent.session_summaries.nodes.execute_summarize_session")
598+
@patch("ee.hogai.chat_agent.session_summaries.nodes._SessionSearch._validate_specific_session_ids")
598599
@patch("ee.hogai.chat_agent.session_summaries.nodes.GROUP_SUMMARIES_MIN_SESSIONS", 5)
599-
def test_arun_use_current_session_with_session_id(self, mock_execute_summarize: MagicMock) -> None:
600+
def test_arun_use_current_session_with_session_id(
601+
self, mock_validate_session_ids: MagicMock, mock_execute_summarize: MagicMock
602+
) -> None:
600603
"""Test arun uses current session ID when specific_session_ids_to_summarize are provided."""
601604
conversation = Conversation.objects.create(team=self.team, user=self.user)
602605
session_id = "00000000-0000-0000-0000-000000000001"
603606

607+
mock_validate_session_ids.return_value = [session_id]
608+
604609
async def mock_summarize_side_effect(*args: Any, **kwargs: Any) -> dict[str, Any]:
605610
return self._session_template(session_id)
606611

@@ -936,3 +941,26 @@ def test_get_session_ids_respects_limit(self) -> None:
936941
self.assertIsNotNone(session_ids)
937942
assert session_ids is not None # Type narrowing for mypy
938943
self.assertEqual(len(session_ids), 1, "Should return exactly 1 session due to limit")
944+
945+
@freeze_time("2025-09-03T12:00:00")
946+
def test_validate_specific_session_ids_filters_invalid(self) -> None:
947+
"""Test that invalid session IDs are filtered out while valid ones are kept."""
948+
valid_ids = [self.session_id_1, self.session_id_2]
949+
invalid_ids = ["nonexistent-session-1", "nonexistent-session-2"]
950+
all_ids = valid_ids + invalid_ids
951+
952+
result = self.node._session_search._validate_specific_session_ids(all_ids)
953+
954+
self.assertIsNotNone(result)
955+
assert result is not None
956+
self.assertEqual(len(result), 2)
957+
self.assertEqual(result, valid_ids)
958+
959+
@freeze_time("2025-09-03T12:00:00")
960+
def test_validate_specific_session_ids_all_invalid_returns_none(self) -> None:
961+
"""Test that when all session IDs are invalid, None is returned."""
962+
invalid_ids = ["nonexistent-session-1", "nonexistent-session-2"]
963+
964+
result = self.node._session_search._validate_specific_session_ids(invalid_ids)
965+
966+
self.assertIsNone(result)

0 commit comments

Comments
 (0)