diff --git a/ee/hogai/chat_agent/session_summaries/nodes.py b/ee/hogai/chat_agent/session_summaries/nodes.py index 8233e747ae9b1..c296762c90550 100644 --- a/ee/hogai/chat_agent/session_summaries/nodes.py +++ b/ee/hogai/chat_agent/session_summaries/nodes.py @@ -355,6 +355,20 @@ def _get_session_ids_with_filters(self, replay_filters: RecordingsQuery) -> list session_ids = [recording["session_id"] for recording in results.results] return session_ids if session_ids else None + def _validate_specific_session_ids(self, session_ids: list[str]) -> list[str] | None: + """Validate that specific session IDs exist in the database.""" + from posthog.session_recordings.queries.session_replay_events import SessionReplayEvents + + replay_events = SessionReplayEvents() + sessions_found, _, _ = replay_events.sessions_found_with_timestamps( + session_ids=session_ids, + team=self._node._team, + ) + if not sessions_found: + return None + # Preserve the original order, filtering out invalid sessions + return [sid for sid in session_ids if sid in sessions_found] + async def _generate_filter_query(self, plain_text_query: str, config: RunnableConfig) -> str: """Generate a filter query for the user's summarization query to keep the search context clear""" messages = [ @@ -408,8 +422,13 @@ async def search_sessions( return self._node._create_error_response(self._node._base_error_instructions, state) # Use specific session IDs, if provided if state.specific_session_ids_to_summarize: - # Return session ids right away to use in the next step - return state.specific_session_ids_to_summarize + # Validate that sessions exist before using them + valid_session_ids = await database_sync_to_async( + self._validate_specific_session_ids, thread_sensitive=False + )(state.specific_session_ids_to_summarize) + if not valid_session_ids: + return None + return valid_session_ids # Use current filters, if provided if state.should_use_current_filters: if not current_filters: diff --git a/ee/hogai/chat_agent/session_summaries/test/__snapshots__/test_nodes.ambr b/ee/hogai/chat_agent/session_summaries/test/__snapshots__/test_nodes.ambr index 0b2dddea843ea..ef595d22eaf57 100644 --- a/ee/hogai/chat_agent/session_summaries/test/__snapshots__/test_nodes.ambr +++ b/ee/hogai/chat_agent/session_summaries/test/__snapshots__/test_nodes.ambr @@ -192,3 +192,33 @@ use_hive_partitioning=0 ''' # --- +# name: TestSessionSummarizationNodeFilterGeneration.test_validate_specific_session_ids_all_invalid_returns_none + ''' + + SELECT session_id, + min(min_first_timestamp) as min_timestamp, + max(max_last_timestamp) as max_timestamp, + max(retention_period_days) as retention_period_days, + dateTrunc('DAY', min_timestamp) + toIntervalDay(coalesce(retention_period_days, 21)) as expiry_time + FROM session_replay_events PREWHERE team_id = 99999 + AND session_id IN ['nonexistent-session-1', 'nonexistent-session-2'] + AND min_first_timestamp <= '2025-09-03 12:00:00' + GROUP BY session_id + HAVING expiry_time >= '2025-09-03 12:00:00' + ''' +# --- +# name: TestSessionSummarizationNodeFilterGeneration.test_validate_specific_session_ids_filters_invalid + ''' + + SELECT session_id, + min(min_first_timestamp) as min_timestamp, + max(max_last_timestamp) as max_timestamp, + max(retention_period_days) as retention_period_days, + dateTrunc('DAY', min_timestamp) + toIntervalDay(coalesce(retention_period_days, 21)) as expiry_time + FROM session_replay_events PREWHERE team_id = 99999 + AND session_id IN ['01990f72-b600-7fa3-9a77-341582154177', '01990f72-b600-76df-a71a-f8d777d51361', 'nonexistent-session-1', 'nonexistent-session-2'] + AND min_first_timestamp <= '2025-09-03 12:00:00' + GROUP BY session_id + HAVING expiry_time >= '2025-09-03 12:00:00' + ''' +# --- diff --git a/ee/hogai/chat_agent/session_summaries/test/test_nodes.py b/ee/hogai/chat_agent/session_summaries/test/test_nodes.py index 911648d1a862f..2546688a6d208 100644 --- a/ee/hogai/chat_agent/session_summaries/test/test_nodes.py +++ b/ee/hogai/chat_agent/session_summaries/test/test_nodes.py @@ -595,12 +595,17 @@ def test_arun_both_current_session_and_filters_true_returns_error(self) -> None: self.assertIn("encountered an issue", message.content) @patch("ee.hogai.chat_agent.session_summaries.nodes.execute_summarize_session") + @patch("ee.hogai.chat_agent.session_summaries.nodes._SessionSearch._validate_specific_session_ids") @patch("ee.hogai.chat_agent.session_summaries.nodes.GROUP_SUMMARIES_MIN_SESSIONS", 5) - def test_arun_use_current_session_with_session_id(self, mock_execute_summarize: MagicMock) -> None: + def test_arun_use_current_session_with_session_id( + self, mock_validate_session_ids: MagicMock, mock_execute_summarize: MagicMock + ) -> None: """Test arun uses current session ID when specific_session_ids_to_summarize are provided.""" conversation = Conversation.objects.create(team=self.team, user=self.user) session_id = "00000000-0000-0000-0000-000000000001" + mock_validate_session_ids.return_value = [session_id] + async def mock_summarize_side_effect(*args: Any, **kwargs: Any) -> dict[str, Any]: return self._session_template(session_id) @@ -936,3 +941,26 @@ def test_get_session_ids_respects_limit(self) -> None: self.assertIsNotNone(session_ids) assert session_ids is not None # Type narrowing for mypy self.assertEqual(len(session_ids), 1, "Should return exactly 1 session due to limit") + + @freeze_time("2025-09-03T12:00:00") + def test_validate_specific_session_ids_filters_invalid(self) -> None: + """Test that invalid session IDs are filtered out while valid ones are kept.""" + valid_ids = [self.session_id_1, self.session_id_2] + invalid_ids = ["nonexistent-session-1", "nonexistent-session-2"] + all_ids = valid_ids + invalid_ids + + result = self.node._session_search._validate_specific_session_ids(all_ids) + + self.assertIsNotNone(result) + assert result is not None + self.assertEqual(len(result), 2) + self.assertEqual(result, valid_ids) + + @freeze_time("2025-09-03T12:00:00") + def test_validate_specific_session_ids_all_invalid_returns_none(self) -> None: + """Test that when all session IDs are invalid, None is returned.""" + invalid_ids = ["nonexistent-session-1", "nonexistent-session-2"] + + result = self.node._session_search._validate_specific_session_ids(invalid_ids) + + self.assertIsNone(result)