1+ from uuid import uuid4
2+
13from posthog .test .base import APIBaseTest
24from unittest .mock import AsyncMock , patch
35
4- from posthog .schema import AssistantMessage , AssistantToolCallMessage , ContextMessage
6+ from posthog .schema import (
7+ ArtifactContentType ,
8+ ArtifactSource ,
9+ AssistantMessage ,
10+ AssistantToolCallMessage ,
11+ ContextMessage ,
12+ )
513
614from ee .hogai .api .serializers import ConversationSerializer
715from ee .hogai .chat_agent import AssistantGraph
816from ee .hogai .utils .types import AssistantState
9- from ee .models .assistant import Conversation
17+ from ee .hogai .utils .types .base import ArtifactRefMessage
18+ from ee .models .assistant import AgentArtifact , Conversation
1019
1120
1221class TestConversationSerializers (APIBaseTest ):
@@ -67,8 +76,8 @@ class MockSnapshot:
6776 # Third message should be the AssistantToolCallMessage without UI payload
6877 self .assertEqual (filtered_messages [2 ]["ui_payload" ], None )
6978
70- def test_get_messages_handles_validation_errors (self ):
71- """Gracefully fall back to an empty list when the stored state fails validation."""
79+ def test_get_messages_handles_validation_errors_and_sets_unsupported_content (self ):
80+ """Gracefully fall back to an empty list when the stored state fails validation, and set has_unsupported_content ."""
7281 conversation = Conversation .objects .create (
7382 user = self .user , team = self .team , title = "Conversation with invalid state" , type = Conversation .Type .ASSISTANT
7483 )
@@ -88,40 +97,74 @@ def test_get_messages_handles_validation_errors(self):
8897 ).data
8998
9099 self .assertEqual (data ["messages" ], [])
100+ self .assertTrue (data ["has_unsupported_content" ])
91101
92- def test_has_unsupported_content_on_validation_error (self ):
93- """When validation fails, has_unsupported_content should be True ."""
102+ def test_caching_prevents_duplicate_operations (self ):
103+ """This is to test that the caching works correctly as to not incurring in unnecessary operations (We would do a DRF call per field call) ."""
94104 conversation = Conversation .objects .create (
95- user = self .user ,
96- team = self .team ,
97- title = "Conversation with schema mismatch" ,
98- type = Conversation .Type .ASSISTANT ,
105+ user = self .user , team = self .team , title = "Cached conversation" , type = Conversation .Type .ASSISTANT
99106 )
100107
101- invalid_snapshot = type ( "Snapshot" , (), { "values" : { "messages" : [{ "invalid" : "schema" }]}})( )
108+ state = AssistantState ( messages = [ AssistantMessage ( content = "Cached message" , type = "ai" )] )
102109
103110 with patch ("langgraph.graph.state.CompiledStateGraph.aget_state" , new_callable = AsyncMock ) as mock_get_state :
104- mock_get_state .return_value = invalid_snapshot
105111
106- data = ConversationSerializer (
112+ class MockSnapshot :
113+ values = state .model_dump ()
114+
115+ mock_get_state .return_value = MockSnapshot ()
116+
117+ serializer = ConversationSerializer (
107118 conversation ,
108119 context = {
109120 "team" : self .team ,
110121 "user" : self .user ,
111122 },
112- ). data
123+ )
113124
114- self .assertEqual (data ["messages" ], [])
115- self .assertTrue (data ["has_unsupported_content" ])
125+ # Explicitly access both fields multiple times
126+ _ = serializer .data ["messages" ]
127+ _ = serializer .data ["has_unsupported_content" ]
128+ _ = serializer .data ["messages" ]
129+ _ = serializer .data ["has_unsupported_content" ]
116130
117- def test_has_unsupported_content_on_other_errors (self ):
118- """On non-validation errors, has_unsupported_content should be False."""
131+ # aget_state should only be called once though
132+ self .assertEqual (mock_get_state .call_count , 1 )
133+
134+
135+ class TestConversationSerializerArtifactEnrichment (APIBaseTest ):
136+ """Test artifact enrichment functionality in the serializer."""
137+
138+ def test_artifact_ref_message_enriched_in_response (self ):
139+ """Test that ArtifactRefMessage is enriched with content from database artifact."""
119140 conversation = Conversation .objects .create (
120- user = self .user , team = self .team , title = "Conversation with graph error " , type = Conversation .Type .ASSISTANT
141+ user = self .user , team = self .team , title = "Artifact test conversation " , type = Conversation .Type .ASSISTANT
121142 )
122143
144+ # Create an artifact in the database
145+ artifact = AgentArtifact .objects .create (
146+ name = "Test Artifact" ,
147+ type = AgentArtifact .Type .VISUALIZATION ,
148+ data = {"query" : {"kind" : "TrendsQuery" , "series" : []}, "name" : "Chart Name" },
149+ conversation = conversation ,
150+ team = self .team ,
151+ )
152+
153+ # Create state with an ArtifactRefMessage
154+ artifact_message = ArtifactRefMessage (
155+ id = str (uuid4 ()),
156+ content_type = ArtifactContentType .VISUALIZATION ,
157+ artifact_id = artifact .short_id ,
158+ source = ArtifactSource .ARTIFACT ,
159+ )
160+ state = AssistantState (messages = [artifact_message ])
161+
123162 with patch ("langgraph.graph.state.CompiledStateGraph.aget_state" , new_callable = AsyncMock ) as mock_get_state :
124- mock_get_state .side_effect = RuntimeError ("Graph compilation failed" )
163+
164+ class MockSnapshot :
165+ values = state .model_dump ()
166+
167+ mock_get_state .return_value = MockSnapshot ()
125168
126169 data = ConversationSerializer (
127170 conversation ,
@@ -131,16 +174,27 @@ def test_has_unsupported_content_on_other_errors(self):
131174 },
132175 ).data
133176
134- self .assertEqual (data ["messages" ], [])
135- self .assertFalse (data ["has_unsupported_content" ])
177+ # The message should be enriched as an ArtifactMessage
178+ self .assertEqual (len (data ["messages" ]), 1 )
179+ enriched_msg = data ["messages" ][0 ]
180+ self .assertEqual (enriched_msg ["type" ], "ai/artifact" )
181+ self .assertEqual (enriched_msg ["artifact_id" ], artifact .short_id )
182+ self .assertEqual (enriched_msg ["content" ]["name" ], "Chart Name" )
136183
137- def test_has_unsupported_content_on_success (self ):
138- """On successful message fetch, has_unsupported_content should be False ."""
184+ def test_artifact_ref_message_filtered_when_not_found (self ):
185+ """Test that ArtifactRefMessage is filtered out when artifact not found in database ."""
139186 conversation = Conversation .objects .create (
140- user = self .user , team = self .team , title = "Valid conversation" , type = Conversation .Type .ASSISTANT
187+ user = self .user , team = self .team , title = "Missing artifact conversation" , type = Conversation .Type .ASSISTANT
141188 )
142189
143- state = AssistantState (messages = [AssistantMessage (content = "Test message" , type = "ai" )])
190+ # Create state with an ArtifactRefMessage pointing to non-existent artifact
191+ artifact_message = ArtifactRefMessage (
192+ id = str (uuid4 ()),
193+ content_type = ArtifactContentType .VISUALIZATION ,
194+ artifact_id = "nonexistent" ,
195+ source = ArtifactSource .ARTIFACT ,
196+ )
197+ state = AssistantState (messages = [artifact_message ])
144198
145199 with patch ("langgraph.graph.state.CompiledStateGraph.aget_state" , new_callable = AsyncMock ) as mock_get_state :
146200
@@ -157,16 +211,32 @@ class MockSnapshot:
157211 },
158212 ).data
159213
160- self . assertEqual ( len ( data [ "messages" ]), 1 )
161- self .assertFalse ( data ["has_unsupported_content" ] )
214+ # The message should be filtered out
215+ self .assertEqual ( len ( data ["messages" ]), 0 )
162216
163- def test_caching_prevents_duplicate_operations (self ):
164- """This is to test that the caching works correctly as to not incurring in unnecessary operations (We would do a DRF call per field call) ."""
217+ def test_mixed_messages_with_artifacts (self ):
218+ """Test serialization with mixed message types including artifacts ."""
165219 conversation = Conversation .objects .create (
166- user = self .user , team = self .team , title = "Cached conversation" , type = Conversation .Type .ASSISTANT
220+ user = self .user , team = self .team , title = "Mixed messages conversation" , type = Conversation .Type .ASSISTANT
167221 )
168222
169- state = AssistantState (messages = [AssistantMessage (content = "Cached message" , type = "ai" )])
223+ artifact = AgentArtifact .objects .create (
224+ name = "Mixed Artifact" ,
225+ type = AgentArtifact .Type .VISUALIZATION ,
226+ data = {"query" : {"kind" : "TrendsQuery" , "series" : []}, "name" : "Mixed Chart" },
227+ conversation = conversation ,
228+ team = self .team ,
229+ )
230+
231+ # Create state with mixed message types
232+ assistant_message = AssistantMessage (content = "Hello from assistant" , type = "ai" )
233+ artifact_message = ArtifactRefMessage (
234+ id = str (uuid4 ()),
235+ content_type = ArtifactContentType .VISUALIZATION ,
236+ artifact_id = artifact .short_id ,
237+ source = ArtifactSource .ARTIFACT ,
238+ )
239+ state = AssistantState (messages = [assistant_message , artifact_message ])
170240
171241 with patch ("langgraph.graph.state.CompiledStateGraph.aget_state" , new_callable = AsyncMock ) as mock_get_state :
172242
@@ -175,19 +245,17 @@ class MockSnapshot:
175245
176246 mock_get_state .return_value = MockSnapshot ()
177247
178- serializer = ConversationSerializer (
248+ data = ConversationSerializer (
179249 conversation ,
180250 context = {
181251 "team" : self .team ,
182252 "user" : self .user ,
183253 },
184- )
185-
186- # Explicitly access both fields multiple times
187- _ = serializer .data ["messages" ]
188- _ = serializer .data ["has_unsupported_content" ]
189- _ = serializer .data ["messages" ]
190- _ = serializer .data ["has_unsupported_content" ]
254+ ).data
191255
192- # aget_state should only be called once though
193- self .assertEqual (mock_get_state .call_count , 1 )
256+ # Both messages should be included (AssistantMessage and enriched ArtifactMessage)
257+ self .assertEqual (len (data ["messages" ]), 2 )
258+ self .assertEqual (data ["messages" ][0 ]["type" ], "ai" )
259+ self .assertEqual (data ["messages" ][0 ]["content" ], "Hello from assistant" )
260+ self .assertEqual (data ["messages" ][1 ]["type" ], "ai/artifact" )
261+ self .assertEqual (data ["messages" ][1 ]["content" ]["name" ], "Mixed Chart" )
0 commit comments