Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions ee/hogai/chat_agent/session_summaries/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,20 @@ def _get_session_ids_with_filters(self, replay_filters: RecordingsQuery) -> list
session_ids = [recording["session_id"] for recording in results.results]
return session_ids if session_ids else None

def _validate_specific_session_ids(self, session_ids: list[str]) -> list[str] | None:
"""Validate that specific session IDs exist in the database."""
from posthog.session_recordings.queries.session_replay_events import SessionReplayEvents

replay_events = SessionReplayEvents()
sessions_found, _, _ = replay_events.sessions_found_with_timestamps(
session_ids=session_ids,
team=self._node._team,
)
if not sessions_found:
return None
# Preserve the original order, filtering out invalid sessions
return [sid for sid in session_ids if sid in sessions_found]

async def _generate_filter_query(self, plain_text_query: str, config: RunnableConfig) -> str:
"""Generate a filter query for the user's summarization query to keep the search context clear"""
messages = [
Expand Down Expand Up @@ -408,8 +422,13 @@ async def search_sessions(
return self._node._create_error_response(self._node._base_error_instructions, state)
# Use specific session IDs, if provided
if state.specific_session_ids_to_summarize:
# Return session ids right away to use in the next step
return state.specific_session_ids_to_summarize
# Validate that sessions exist before using them
valid_session_ids = await database_sync_to_async(
self._validate_specific_session_ids, thread_sensitive=False
)(state.specific_session_ids_to_summarize)
if not valid_session_ids:
return None
return valid_session_ids
# Use current filters, if provided
if state.should_use_current_filters:
if not current_filters:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,33 @@
use_hive_partitioning=0
'''
# ---
# name: TestSessionSummarizationNodeFilterGeneration.test_validate_specific_session_ids_all_invalid_returns_none
'''

SELECT session_id,
min(min_first_timestamp) as min_timestamp,
max(max_last_timestamp) as max_timestamp,
max(retention_period_days) as retention_period_days,
dateTrunc('DAY', min_timestamp) + toIntervalDay(coalesce(retention_period_days, 21)) as expiry_time
FROM session_replay_events PREWHERE team_id = 99999
AND session_id IN ['nonexistent-session-1', 'nonexistent-session-2']
AND min_first_timestamp <= '2025-09-03 12:00:00'
GROUP BY session_id
HAVING expiry_time >= '2025-09-03 12:00:00'
'''
# ---
# name: TestSessionSummarizationNodeFilterGeneration.test_validate_specific_session_ids_filters_invalid
'''

SELECT session_id,
min(min_first_timestamp) as min_timestamp,
max(max_last_timestamp) as max_timestamp,
max(retention_period_days) as retention_period_days,
dateTrunc('DAY', min_timestamp) + toIntervalDay(coalesce(retention_period_days, 21)) as expiry_time
FROM session_replay_events PREWHERE team_id = 99999
AND session_id IN ['01990f72-b600-7fa3-9a77-341582154177', '01990f72-b600-76df-a71a-f8d777d51361', 'nonexistent-session-1', 'nonexistent-session-2']
AND min_first_timestamp <= '2025-09-03 12:00:00'
GROUP BY session_id
HAVING expiry_time >= '2025-09-03 12:00:00'
'''
# ---
30 changes: 29 additions & 1 deletion ee/hogai/chat_agent/session_summaries/test/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,12 +595,17 @@ def test_arun_both_current_session_and_filters_true_returns_error(self) -> None:
self.assertIn("encountered an issue", message.content)

@patch("ee.hogai.chat_agent.session_summaries.nodes.execute_summarize_session")
@patch("ee.hogai.chat_agent.session_summaries.nodes._SessionSearch._validate_specific_session_ids")
@patch("ee.hogai.chat_agent.session_summaries.nodes.GROUP_SUMMARIES_MIN_SESSIONS", 5)
def test_arun_use_current_session_with_session_id(self, mock_execute_summarize: MagicMock) -> None:
def test_arun_use_current_session_with_session_id(
self, mock_validate_session_ids: MagicMock, mock_execute_summarize: MagicMock
) -> None:
"""Test arun uses current session ID when specific_session_ids_to_summarize are provided."""
conversation = Conversation.objects.create(team=self.team, user=self.user)
session_id = "00000000-0000-0000-0000-000000000001"

mock_validate_session_ids.return_value = [session_id]

async def mock_summarize_side_effect(*args: Any, **kwargs: Any) -> dict[str, Any]:
return self._session_template(session_id)

Expand Down Expand Up @@ -936,3 +941,26 @@ def test_get_session_ids_respects_limit(self) -> None:
self.assertIsNotNone(session_ids)
assert session_ids is not None # Type narrowing for mypy
self.assertEqual(len(session_ids), 1, "Should return exactly 1 session due to limit")

@freeze_time("2025-09-03T12:00:00")
def test_validate_specific_session_ids_filters_invalid(self) -> None:
    """Unknown session IDs are dropped; known ones come back in their original order."""
    known_ids = [self.session_id_1, self.session_id_2]
    unknown_ids = ["nonexistent-session-1", "nonexistent-session-2"]

    # Known IDs first, unknown IDs after - the same order the original test submits.
    validated = self.node._session_search._validate_specific_session_ids([*known_ids, *unknown_ids])

    self.assertIsNotNone(validated)
    assert validated is not None  # Type narrowing for mypy
    self.assertEqual(len(validated), 2)
    self.assertEqual(validated, known_ids)

@freeze_time("2025-09-03T12:00:00")
def test_validate_specific_session_ids_all_invalid_returns_none(self) -> None:
    """Validation yields None when none of the requested session IDs are found."""
    unknown_ids = ["nonexistent-session-1", "nonexistent-session-2"]

    validated = self.node._session_search._validate_specific_session_ids(unknown_ids)

    self.assertIsNone(validated)
Loading