Skip to content

Commit 626c8f6

Browse files
kappa90sakce
authored andcommitted
feat(ph-ai): use artifacts for visualizations (#41947)
## Problem We want to migrate to smarter artifact management. This PR introduces an artifact manager that can handle visualization objects, either created by the assistant, or retrieved from the user's ones. ## Changes - Created a new `ArtifactManager` class to handle creation, retrieval, and enrichment of agent artifacts - Updated stream processing to handle artifact messages and convert them to appropriate visualization formats - Refactored frontend component to use artifact visualizations. ## How did you test this code? - Locally, tests to be fixed
1 parent 9e31449 commit 626c8f6

File tree

61 files changed

+3152
-1057
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+3152
-1057
lines changed

ee/hogai/api/serializers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from posthog.api.shared import UserBasicSerializer
99
from posthog.exceptions_capture import capture_exception
1010

11+
from ee.hogai.artifacts.manager import ArtifactManager
1112
from ee.hogai.chat_agent import AssistantGraph
1213
from ee.hogai.utils.helpers import should_output_assistant_message
1314
from ee.hogai.utils.types import AssistantState
@@ -75,8 +76,11 @@ async def _aget_messages_with_flag(self, conversation: Conversation) -> tuple[li
7576
{"configurable": {"thread_id": str(conversation.id), "checkpoint_ns": ""}}
7677
)
7778
state = state_class.model_validate(snapshot.values)
78-
messages = [message.model_dump() for message in state.messages if should_output_assistant_message(message)]
79-
return messages, False
79+
messages = list(state.messages)
80+
artifact_manager = ArtifactManager(team, user)
81+
enriched_messages = await artifact_manager.aenrich_messages(messages)
82+
return [m.model_dump() for m in enriched_messages if should_output_assistant_message(m)], False
83+
8084
except pydantic.ValidationError as e:
8185
capture_exception(
8286
e,

ee/hogai/api/test/test_serializers.py

Lines changed: 110 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
1+
from uuid import uuid4
2+
13
from posthog.test.base import APIBaseTest
24
from unittest.mock import AsyncMock, patch
35

4-
from posthog.schema import AssistantMessage, AssistantToolCallMessage, ContextMessage
6+
from posthog.schema import (
7+
ArtifactContentType,
8+
ArtifactSource,
9+
AssistantMessage,
10+
AssistantToolCallMessage,
11+
ContextMessage,
12+
)
513

614
from ee.hogai.api.serializers import ConversationSerializer
715
from ee.hogai.chat_agent import AssistantGraph
816
from ee.hogai.utils.types import AssistantState
9-
from ee.models.assistant import Conversation
17+
from ee.hogai.utils.types.base import ArtifactRefMessage
18+
from ee.models.assistant import AgentArtifact, Conversation
1019

1120

1221
class TestConversationSerializers(APIBaseTest):
@@ -67,8 +76,8 @@ class MockSnapshot:
6776
# Third message should be the AssistantToolCallMessage without UI payload
6877
self.assertEqual(filtered_messages[2]["ui_payload"], None)
6978

70-
def test_get_messages_handles_validation_errors(self):
71-
"""Gracefully fall back to an empty list when the stored state fails validation."""
79+
def test_get_messages_handles_validation_errors_and_sets_unsupported_content(self):
80+
"""Gracefully fall back to an empty list when the stored state fails validation, and set has_unsupported_content."""
7281
conversation = Conversation.objects.create(
7382
user=self.user, team=self.team, title="Conversation with invalid state", type=Conversation.Type.ASSISTANT
7483
)
@@ -88,40 +97,74 @@ def test_get_messages_handles_validation_errors(self):
8897
).data
8998

9099
self.assertEqual(data["messages"], [])
100+
self.assertTrue(data["has_unsupported_content"])
91101

92-
def test_has_unsupported_content_on_validation_error(self):
93-
"""When validation fails, has_unsupported_content should be True."""
102+
def test_caching_prevents_duplicate_operations(self):
103+
"""This is to test that the caching works correctly as to not incurring in unnecessary operations (We would do a DRF call per field call)."""
94104
conversation = Conversation.objects.create(
95-
user=self.user,
96-
team=self.team,
97-
title="Conversation with schema mismatch",
98-
type=Conversation.Type.ASSISTANT,
105+
user=self.user, team=self.team, title="Cached conversation", type=Conversation.Type.ASSISTANT
99106
)
100107

101-
invalid_snapshot = type("Snapshot", (), {"values": {"messages": [{"invalid": "schema"}]}})()
108+
state = AssistantState(messages=[AssistantMessage(content="Cached message", type="ai")])
102109

103110
with patch("langgraph.graph.state.CompiledStateGraph.aget_state", new_callable=AsyncMock) as mock_get_state:
104-
mock_get_state.return_value = invalid_snapshot
105111

106-
data = ConversationSerializer(
112+
class MockSnapshot:
113+
values = state.model_dump()
114+
115+
mock_get_state.return_value = MockSnapshot()
116+
117+
serializer = ConversationSerializer(
107118
conversation,
108119
context={
109120
"team": self.team,
110121
"user": self.user,
111122
},
112-
).data
123+
)
113124

114-
self.assertEqual(data["messages"], [])
115-
self.assertTrue(data["has_unsupported_content"])
125+
# Explicitly access both fields multiple times
126+
_ = serializer.data["messages"]
127+
_ = serializer.data["has_unsupported_content"]
128+
_ = serializer.data["messages"]
129+
_ = serializer.data["has_unsupported_content"]
116130

117-
def test_has_unsupported_content_on_other_errors(self):
118-
"""On non-validation errors, has_unsupported_content should be False."""
131+
# aget_state should only be called once though
132+
self.assertEqual(mock_get_state.call_count, 1)
133+
134+
135+
class TestConversationSerializerArtifactEnrichment(APIBaseTest):
136+
"""Test artifact enrichment functionality in the serializer."""
137+
138+
def test_artifact_ref_message_enriched_in_response(self):
139+
"""Test that ArtifactRefMessage is enriched with content from database artifact."""
119140
conversation = Conversation.objects.create(
120-
user=self.user, team=self.team, title="Conversation with graph error", type=Conversation.Type.ASSISTANT
141+
user=self.user, team=self.team, title="Artifact test conversation", type=Conversation.Type.ASSISTANT
121142
)
122143

144+
# Create an artifact in the database
145+
artifact = AgentArtifact.objects.create(
146+
name="Test Artifact",
147+
type=AgentArtifact.Type.VISUALIZATION,
148+
data={"query": {"kind": "TrendsQuery", "series": []}, "name": "Chart Name"},
149+
conversation=conversation,
150+
team=self.team,
151+
)
152+
153+
# Create state with an ArtifactRefMessage
154+
artifact_message = ArtifactRefMessage(
155+
id=str(uuid4()),
156+
content_type=ArtifactContentType.VISUALIZATION,
157+
artifact_id=artifact.short_id,
158+
source=ArtifactSource.ARTIFACT,
159+
)
160+
state = AssistantState(messages=[artifact_message])
161+
123162
with patch("langgraph.graph.state.CompiledStateGraph.aget_state", new_callable=AsyncMock) as mock_get_state:
124-
mock_get_state.side_effect = RuntimeError("Graph compilation failed")
163+
164+
class MockSnapshot:
165+
values = state.model_dump()
166+
167+
mock_get_state.return_value = MockSnapshot()
125168

126169
data = ConversationSerializer(
127170
conversation,
@@ -131,16 +174,27 @@ def test_has_unsupported_content_on_other_errors(self):
131174
},
132175
).data
133176

134-
self.assertEqual(data["messages"], [])
135-
self.assertFalse(data["has_unsupported_content"])
177+
# The message should be enriched as an ArtifactMessage
178+
self.assertEqual(len(data["messages"]), 1)
179+
enriched_msg = data["messages"][0]
180+
self.assertEqual(enriched_msg["type"], "ai/artifact")
181+
self.assertEqual(enriched_msg["artifact_id"], artifact.short_id)
182+
self.assertEqual(enriched_msg["content"]["name"], "Chart Name")
136183

137-
def test_has_unsupported_content_on_success(self):
138-
"""On successful message fetch, has_unsupported_content should be False."""
184+
def test_artifact_ref_message_filtered_when_not_found(self):
185+
"""Test that ArtifactRefMessage is filtered out when artifact not found in database."""
139186
conversation = Conversation.objects.create(
140-
user=self.user, team=self.team, title="Valid conversation", type=Conversation.Type.ASSISTANT
187+
user=self.user, team=self.team, title="Missing artifact conversation", type=Conversation.Type.ASSISTANT
141188
)
142189

143-
state = AssistantState(messages=[AssistantMessage(content="Test message", type="ai")])
190+
# Create state with an ArtifactRefMessage pointing to non-existent artifact
191+
artifact_message = ArtifactRefMessage(
192+
id=str(uuid4()),
193+
content_type=ArtifactContentType.VISUALIZATION,
194+
artifact_id="nonexistent",
195+
source=ArtifactSource.ARTIFACT,
196+
)
197+
state = AssistantState(messages=[artifact_message])
144198

145199
with patch("langgraph.graph.state.CompiledStateGraph.aget_state", new_callable=AsyncMock) as mock_get_state:
146200

@@ -157,16 +211,32 @@ class MockSnapshot:
157211
},
158212
).data
159213

160-
self.assertEqual(len(data["messages"]), 1)
161-
self.assertFalse(data["has_unsupported_content"])
214+
# The message should be filtered out
215+
self.assertEqual(len(data["messages"]), 0)
162216

163-
def test_caching_prevents_duplicate_operations(self):
164-
"""This is to test that the caching works correctly as to not incurring in unnecessary operations (We would do a DRF call per field call)."""
217+
def test_mixed_messages_with_artifacts(self):
218+
"""Test serialization with mixed message types including artifacts."""
165219
conversation = Conversation.objects.create(
166-
user=self.user, team=self.team, title="Cached conversation", type=Conversation.Type.ASSISTANT
220+
user=self.user, team=self.team, title="Mixed messages conversation", type=Conversation.Type.ASSISTANT
167221
)
168222

169-
state = AssistantState(messages=[AssistantMessage(content="Cached message", type="ai")])
223+
artifact = AgentArtifact.objects.create(
224+
name="Mixed Artifact",
225+
type=AgentArtifact.Type.VISUALIZATION,
226+
data={"query": {"kind": "TrendsQuery", "series": []}, "name": "Mixed Chart"},
227+
conversation=conversation,
228+
team=self.team,
229+
)
230+
231+
# Create state with mixed message types
232+
assistant_message = AssistantMessage(content="Hello from assistant", type="ai")
233+
artifact_message = ArtifactRefMessage(
234+
id=str(uuid4()),
235+
content_type=ArtifactContentType.VISUALIZATION,
236+
artifact_id=artifact.short_id,
237+
source=ArtifactSource.ARTIFACT,
238+
)
239+
state = AssistantState(messages=[assistant_message, artifact_message])
170240

171241
with patch("langgraph.graph.state.CompiledStateGraph.aget_state", new_callable=AsyncMock) as mock_get_state:
172242

@@ -175,19 +245,17 @@ class MockSnapshot:
175245

176246
mock_get_state.return_value = MockSnapshot()
177247

178-
serializer = ConversationSerializer(
248+
data = ConversationSerializer(
179249
conversation,
180250
context={
181251
"team": self.team,
182252
"user": self.user,
183253
},
184-
)
185-
186-
# Explicitly access both fields multiple times
187-
_ = serializer.data["messages"]
188-
_ = serializer.data["has_unsupported_content"]
189-
_ = serializer.data["messages"]
190-
_ = serializer.data["has_unsupported_content"]
254+
).data
191255

192-
# aget_state should only be called once though
193-
self.assertEqual(mock_get_state.call_count, 1)
256+
# Both messages should be included (AssistantMessage and enriched ArtifactMessage)
257+
self.assertEqual(len(data["messages"]), 2)
258+
self.assertEqual(data["messages"][0]["type"], "ai")
259+
self.assertEqual(data["messages"][0]["content"], "Hello from assistant")
260+
self.assertEqual(data["messages"][1]["type"], "ai/artifact")
261+
self.assertEqual(data["messages"][1]["content"]["name"], "Mixed Chart")

ee/hogai/artifacts/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from posthog.schema import (
2+
ArtifactSource,
3+
DocumentArtifactContent,
4+
MarkdownBlock,
5+
SessionReplayBlock,
6+
VisualizationArtifactContent,
7+
VisualizationBlock,
8+
)
9+
10+
from .manager import ArtifactManager
11+
12+
__all__ = [
13+
"ArtifactSource",
14+
"ArtifactManager",
15+
"DocumentArtifactContent",
16+
"MarkdownBlock",
17+
"SessionReplayBlock",
18+
"VisualizationArtifactContent",
19+
"VisualizationBlock",
20+
]

0 commit comments

Comments
 (0)