3 changes: 3 additions & 0 deletions ee/hogai/session_summaries/constants.py
@@ -15,6 +15,9 @@
HALLUCINATED_EVENTS_MIN_RATIO = 0.15 # If more than 15% of events in the summary are hallucinated, fail the summarization
# Minimum number of sessions to use group summary logic (find patterns) instead of summarizing them separately
GROUP_SUMMARIES_MIN_SESSIONS = 5
EVENTS_BEFORE_REPLAY_START_THRESHOLD_MS = (
1000 # Don't include events before the cutoff as they are not visible in the replay
)

# Temporal
SESSION_SUMMARIES_DB_DATA_REDIS_TTL = 60 * 60 * 24 # How long to store the DB data in Redis within Temporal jobs
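For orientation, a minimal sketch of how this cutoff is applied later in the PR; the helper name here is hypothetical (the real logic lives inline in `add_context_and_filter_events`):

```python
from datetime import datetime

EVENTS_BEFORE_REPLAY_START_THRESHOLD_MS = 1000

def is_within_pre_replay_cutoff(event_ts: datetime, replay_start: datetime) -> bool:
    """Hypothetical helper: True when the event can't be confirmed against the replay video."""
    # Negative offsets (event before replay start) are clamped to 0, so they always fall under the cutoff
    ms_since_start = max(0, int((event_ts - replay_start).total_seconds() * 1000))
    return ms_since_start <= EVENTS_BEFORE_REPLAY_START_THRESHOLD_MS
```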
18 changes: 17 additions & 1 deletion ee/hogai/session_summaries/session/input_data.py
@@ -10,11 +10,12 @@
from posthog.session_recordings.models.metadata import RecordingMetadata
from posthog.session_recordings.queries.session_replay_events import SessionReplayEvents

from ee.hogai.session_summaries.constants import EVENTS_BEFORE_REPLAY_START_THRESHOLD_MS
from ee.hogai.session_summaries.local.input_data import (
_get_production_session_events_locally,
_get_production_session_metadata_locally,
)
from ee.hogai.session_summaries.utils import get_column_index
from ee.hogai.session_summaries.utils import calculate_time_since_start, get_column_index

logger = structlog.get_logger(__name__)

@@ -179,7 +180,9 @@ def add_context_and_filter_events(
session_events_columns: list[str],
session_events: list[tuple[str | datetime.datetime | list[str] | None, ...]],
session_id: str,
session_start_time: datetime.datetime,
) -> tuple[list[str], list[tuple[str | datetime.datetime | list[str] | None, ...]]]:
timestamp_index = get_column_index(session_events_columns, "timestamp")
indexes = {
"event": get_column_index(session_events_columns, "event"),
"$event_type": get_column_index(session_events_columns, "$event_type"),
@@ -199,7 +202,20 @@
i for i, col in enumerate(session_events_columns) if col not in COLUMNS_TO_REMOVE_FROM_LLM_CONTEXT
]
updated_events = []
# Events are chronologically ordered, so once we find an event after replay start, all subsequent events are too
past_replay_start = False
for event in session_events:
# Filter out events that occurred before or exactly at replay start, as we can't confirm them with video
if not past_replay_start:
event_timestamp = event[timestamp_index]
if not isinstance(event_timestamp, (str, datetime.datetime)):
msg = f"Event timestamp is not a string or datetime: {event_timestamp}"
logger.error(msg, signals_type="session-summaries", session_id=session_id)
raise ValueError(msg)
ms_since_start = calculate_time_since_start(event_timestamp, session_start_time)
if ms_since_start <= EVENTS_BEFORE_REPLAY_START_THRESHOLD_MS:
continue
past_replay_start = True
updated_event: list[str | datetime.datetime | list[str] | None] = list(event)
# Check for errors worth keeping in the context
if event[indexes["event"]] == "$exception":
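The hunk above is truncated by the diff view; as a self-contained sketch, the filtering loop it adds behaves like this (simplified event shape, threshold assumed from constants.py above):

```python
from datetime import UTC, datetime

EVENTS_BEFORE_REPLAY_START_THRESHOLD_MS = 1000  # assumed, mirrors constants.py

def filter_pre_replay_events(
    events: list[tuple[str, datetime]],  # (event_name, timestamp), chronologically ordered
    session_start_time: datetime,
) -> list[tuple[str, datetime]]:
    kept = []
    # Events are chronologically ordered, so once one event passes the cutoff, all later ones do too
    past_replay_start = False
    for name, ts in events:
        if not past_replay_start:
            ms_since_start = max(0, int((ts - session_start_time).total_seconds() * 1000))
            if ms_since_start <= EVENTS_BEFORE_REPLAY_START_THRESHOLD_MS:
                continue  # not visible in the replay video, drop it
            past_replay_start = True
        kept.append((name, ts))
    return kept

start = datetime(2025, 3, 31, 18, 40, 32, 302000, tzinfo=UTC)
events = [
    ("$pageview", start),  # exactly at replay start, filtered
    ("$autocapture", datetime(2025, 3, 31, 18, 40, 39, 302000, tzinfo=UTC)),  # 7s in, kept
]
assert [name for name, _ in filter_pre_replay_events(events, start)] == ["$autocapture"]
```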
35 changes: 18 additions & 17 deletions ee/hogai/session_summaries/session/output_data.py
@@ -1,4 +1,3 @@
from datetime import datetime
from enum import Enum
from typing import Any

@@ -7,7 +6,12 @@

from ee.hogai.session_summaries import SummaryValidationError
from ee.hogai.session_summaries.constants import HALLUCINATED_EVENTS_MIN_RATIO
from ee.hogai.session_summaries.utils import get_column_index, prepare_datetime, unpack_full_event_id
from ee.hogai.session_summaries.utils import (
calculate_time_since_start,
get_column_index,
prepare_datetime,
unpack_full_event_id,
)
from ee.hogai.utils.yaml import load_yaml_from_raw_llm_content

logger = structlog.get_logger(__name__)
@@ -274,7 +278,8 @@ def _remove_hallucinated_events(
f"Too many hallucinated events ({len(hallucinated_events)}/{total_summary_events}) for session id ({session_id})"
f"in the raw session summary: {[x[-1] for x in hallucinated_events]} "
)
logger.error(msg, session_id=session_id, signals_type="session-summaries")
if final_validation:
logger.error(msg, session_id=session_id, signals_type="session-summaries")
raise SummaryValidationError(msg)
# Iterate in reverse so removals don't shift earlier indexes
for group_index, event_index, event in reversed(hallucinated_events):
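This hunk (and the two similar ones below) gates error logging on `final_validation`: intermediate attempts still raise `SummaryValidationError` so the caller can retry, but only the final attempt is logged. A condensed sketch of the pattern, with a hypothetical check function:

```python
import structlog

logger = structlog.get_logger(__name__)

class SummaryValidationError(Exception):
    """Stand-in for ee.hogai.session_summaries.SummaryValidationError."""

def validate_or_raise(ok: bool, msg: str, session_id: str, final_validation: bool) -> None:
    if ok:
        return
    if final_validation:
        # Earlier attempts are retried, so only the final failure is worth alerting on
        logger.error(msg, session_id=session_id, signals_type="session-summaries")
    raise SummaryValidationError(msg)
```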
@@ -362,14 +367,6 @@ def load_raw_session_summary_from_llm_content(
return raw_session_summary


# TODO Rework the logic, so events before the recording are marked as "LOAD", not 00:00
def calculate_time_since_start(session_timestamp: str, session_start_time: datetime | None) -> int | None:
if not session_start_time or not session_timestamp:
return None
timestamp_datetime = datetime.fromisoformat(session_timestamp)
return max(0, int((timestamp_datetime - session_start_time).total_seconds() * 1000))


def _validate_enriched_summary(
data: dict[str, Any], session_id: str, final_validation: bool
) -> SessionSummarySerializer:
@@ -454,6 +451,7 @@ def _calculate_segment_meta(
raw_key_actions: list[dict[str, Any]] | None,
session_duration: int,
session_id: str,
final_validation: bool,
) -> SegmentMetaSerializer:
# Find the first and last events in the segment
segment_index = raw_segment.get("index")
@@ -599,11 +597,12 @@ def _calculate_segment_meta(
# TODO: Factor of two is arbitrary, find a better solution
if duration <= 0 or fallback_duration // duration > 2:
# Checking only duration as events are sorted chronologically
logger.warning(
f"Duration change is drastic (fallback: {fallback_duration} -> segments: {duration}) - using fallback data for session_id {session_id}",
session_id=session_id,
signals_type="session-summaries",
)
if final_validation:
logger.warning(
f"Duration change is drastic (fallback: {fallback_duration} -> segments: {duration}) - using fallback data for session_id {session_id}",
session_id=session_id,
signals_type="session-summaries",
)
segment_meta_data["duration"] = fallback_duration
segment_meta_data["duration_percentage"] = fallback_duration_percentage
segment_meta_data["events_count"] = fallback_events_count
@@ -657,12 +656,14 @@ def enrich_raw_session_summary_with_meta(
simplified_events_mapping=simplified_events_mapping,
raw_key_actions=raw_key_actions,
session_id=session_id,
final_validation=final_validation,
)
# Validate the serializer to be able to use `.data`
if not segment_meta.is_valid():
# Most of the fields are optional, so failed validation should be reported
msg = f"Error validating segment meta against the schema when summarizing session_id {session_id}: {segment_meta.errors}"
logger.error(msg, session_id=session_id, signals_type="session-summaries")
if final_validation:
logger.error(msg, session_id=session_id, signals_type="session-summaries")
raise SummaryValidationError(msg)
enriched_segment["meta"] = segment_meta.data
enriched_segments.append(enriched_segment)
5 changes: 4 additions & 1 deletion ee/hogai/session_summaries/session/summarize_session.py
@@ -101,7 +101,10 @@ async def get_session_data_from_db(session_id: str, team_id: int, local_reads_pr
# Raise any unexpected errors
raise
session_events_columns, session_events = add_context_and_filter_events(
session_events_columns=session_events_columns, session_events=session_events, session_id=session_id
session_events_columns=session_events_columns,
session_events=session_events,
session_id=session_id,
session_start_time=session_metadata["start_time"],
)

# TODO Get web analytics data on URLs to better understand what the user was doing
10 changes: 8 additions & 2 deletions ee/hogai/session_summaries/tests/conftest.py
@@ -476,7 +476,13 @@ def mock_chat_completion(mock_valid_llm_yaml_response: str) -> ChatCompletion:


@pytest.fixture
def mock_raw_metadata(mock_session_id: str) -> dict[str, Any]:
def mock_session_start_time() -> datetime:
"""Session replay start time - events before this should be filtered out"""
return datetime(2025, 3, 31, 18, 40, 32, 302000, tzinfo=UTC)


@pytest.fixture
def mock_raw_metadata(mock_session_id: str, mock_session_start_time: datetime) -> dict[str, Any]:
return {
"id": mock_session_id,
# Anonymized distinct_id for testing
@@ -486,7 +492,7 @@ def mock_raw_metadata(mock_session_id: str) -> dict[str, Any]:
"recording_duration": 5323,
"active_seconds": 1947,
"inactive_seconds": 3375,
"start_time": "2025-03-31T18:40:32.302000Z",
"start_time": mock_session_start_time,
"end_time": "2025-03-31T18:54:15.789000Z",
"click_count": 679,
"keypress_count": 668,
107 changes: 98 additions & 9 deletions ee/hogai/session_summaries/tests/test_input_data.py
@@ -1,5 +1,5 @@
import json
from datetime import datetime
from datetime import UTC, datetime
from typing import Any

import pytest
@@ -17,6 +17,9 @@
get_session_events,
)

# Timestamp after the mock session start time (2025-03-31T18:40:32.302000Z)
MOCK_EVENT_TIMESTAMP = datetime(2025, 3, 31, 18, 40, 39, 302000, tzinfo=UTC)


@pytest.fixture
def mock_event_indexes(mock_raw_events_columns: list[str]) -> dict[str, int]:
@@ -95,7 +98,7 @@ def test_get_improved_elements_chain_elements():
(
(
"$autocapture",
None,
MOCK_EVENT_TIMESTAMP,
"",
["Click me"],
[],
@@ -113,7 +116,7 @@
),
(
"$autocapture",
None,
MOCK_EVENT_TIMESTAMP,
"",
["Click me"],
["button"],
@@ -131,7 +134,7 @@
(
(
"$autocapture",
None,
MOCK_EVENT_TIMESTAMP,
"",
[],
[],
@@ -155,7 +158,7 @@
(
(
"$autocapture",
None,
MOCK_EVENT_TIMESTAMP,
"",
[],
[],
@@ -173,7 +176,7 @@
),
(
"$autocapture",
None,
MOCK_EVENT_TIMESTAMP,
"",
["Click me", "Create project"],
["button", "a"],
@@ -191,7 +194,7 @@
(
(
"user_clicked_button",
None,
MOCK_EVENT_TIMESTAMP,
"",
[],
[],
@@ -209,7 +212,7 @@
),
(
"user_clicked_button",
None,
MOCK_EVENT_TIMESTAMP,
"",
[],
[],
@@ -227,13 +230,17 @@
)
def test_add_context_and_filter_events(
mock_event_indexes: dict[str, int],
mock_session_start_time: datetime,
input_event: tuple[Any, ...],
expected_event: tuple[Any, ...] | None,
should_keep: bool,
):
test_columns = list(mock_event_indexes.keys())
updated_columns, updated_events = add_context_and_filter_events(
session_events_columns=test_columns, session_events=[input_event], session_id="test_session_id"
session_events_columns=test_columns,
session_events=[input_event],
session_id="test_session_id",
session_start_time=mock_session_start_time,
)

# Check columns are updated (and columns excessive from LLM context are removed)
@@ -249,6 +256,88 @@ def test_add_context_and_filter_events(
assert len(updated_events) == 0


@pytest.mark.parametrize(
"event_timestamps,expected_kept_count",
[
# All events before replay start - none kept
(
[
datetime(2025, 3, 31, 18, 40, 30, 0, tzinfo=UTC), # Before start
datetime(2025, 3, 31, 18, 40, 31, 0, tzinfo=UTC), # Before start
datetime(2025, 3, 31, 18, 40, 32, 302000, tzinfo=UTC), # Exactly at start (filtered)
],
0,
),
# First event before, second at start, third after - only third kept
(
[
datetime(2025, 3, 31, 18, 40, 30, 0, tzinfo=UTC), # Before start
datetime(2025, 3, 31, 18, 40, 32, 302000, tzinfo=UTC), # Exactly at start (filtered)
datetime(2025, 3, 31, 18, 40, 33, 0, tzinfo=UTC), # After start
],
1,
),
# All events after replay start - all kept
(
[
datetime(2025, 3, 31, 18, 40, 33, 0, tzinfo=UTC), # After start
datetime(2025, 3, 31, 18, 40, 34, 0, tzinfo=UTC), # After start
datetime(2025, 3, 31, 18, 40, 35, 0, tzinfo=UTC), # After start
],
3,
),
# Mix: two before, one after - one kept
(
[
datetime(2025, 3, 31, 18, 40, 30, 0, tzinfo=UTC), # Before start
datetime(2025, 3, 31, 18, 40, 31, 0, tzinfo=UTC), # Before start
datetime(2025, 3, 31, 18, 40, 39, 302000, tzinfo=UTC), # After start
],
1,
),
],
)
def test_filter_events_before_replay_start(
mock_raw_events_columns: list[str],
mock_session_start_time: datetime,
event_timestamps: list[datetime],
expected_kept_count: int,
):
"""Test that events occurring before or exactly at replay start are filtered out."""
# Create events with different timestamps but valid context (so they're not filtered for other reasons)
events: list[tuple[Any, ...]] = []
for i, ts in enumerate(event_timestamps):
events.append(
(
"$pageview", # System event - not filtered for lack of context
ts,
"",
[],
[],
None,
None,
None,
[],
"",
[],
[],
[],
[],
[],
f"00000000-0000-0000-0001-00000000000{i}",
)
)

updated_columns, updated_events = add_context_and_filter_events(
session_events_columns=mock_raw_events_columns,
session_events=events,
session_id="test_session_id",
session_start_time=mock_session_start_time,
)

assert len(updated_events) == expected_kept_count


@pytest.mark.parametrize(
"pages_data,expected_count,expected_iterations,expected_error",
[
6 changes: 3 additions & 3 deletions ee/hogai/session_summaries/tests/test_output_data.py
@@ -194,12 +194,12 @@ def test_load_raw_session_summary_invalid_schema(
("2024-03-01T12:00:02+00:00", datetime(2024, 3, 1, 12, 0, 0, tzinfo=UTC), 2000), # 2 seconds after
("2024-03-01T12:00:00+00:00", datetime(2024, 3, 1, 12, 0, 0, tzinfo=UTC), 0), # same time
("2024-03-01T11:59:59+00:00", datetime(2024, 3, 1, 12, 0, 0, tzinfo=UTC), 0), # 1 second before (clamped to 0)
(None, datetime(2024, 3, 1, 12, 0, 0, tzinfo=UTC), None), # no event time
("2024-03-01T12:00:02+00:00", None, None), # no start time
("2024-03-01T13:00:00+00:00", datetime(2024, 3, 1, 12, 0, 0, tzinfo=UTC), 3600000), # 1 hour after
# Also accepts datetime objects for event_time
(datetime(2024, 3, 1, 12, 0, 2, tzinfo=UTC), datetime(2024, 3, 1, 12, 0, 0, tzinfo=UTC), 2000),
],
)
def test_calculate_time_since_start(event_time: str, start_time: datetime, expected: int) -> None:
def test_calculate_time_since_start(event_time: str | datetime, start_time: datetime, expected: int) -> None:
result = calculate_time_since_start(event_time, start_time)
assert result == expected

10 changes: 10 additions & 0 deletions ee/hogai/session_summaries/utils.py
@@ -168,3 +168,13 @@ def logging_session_ids(session_ids: list[str]) -> str:
"""Log a list of session ids in a readable format."""
# Having 150 chars (4 uuids) is enough to identify the sessions and stay readable
return f"Session IDs: {str(session_ids)[:MAX_SESSION_IDS_COMBINED_LOGGING_LENGTH]}"


def calculate_time_since_start(event_timestamp: str | datetime, session_start_time: datetime) -> int:
"""
Calculate milliseconds between event timestamp and session start time.
Returns 0 for events that occurred before or exactly at session start.
"""
if isinstance(event_timestamp, str):
event_timestamp = datetime.fromisoformat(event_timestamp)
return max(0, int((event_timestamp - session_start_time).total_seconds() * 1000))
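A quick usage check of the relocated helper, using values from the parametrized test above (assumes `calculate_time_since_start` from this module is in scope):

```python
from datetime import UTC, datetime

start = datetime(2024, 3, 1, 12, 0, 0, tzinfo=UTC)
assert calculate_time_since_start("2024-03-01T12:00:02+00:00", start) == 2000  # 2s after start
assert calculate_time_since_start("2024-03-01T11:59:59+00:00", start) == 0  # before start, clamped
assert calculate_time_since_start(datetime(2024, 3, 1, 13, 0, tzinfo=UTC), start) == 3600000  # datetime input
```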