@@ -73,7 +73,7 @@ public function __construct(
7373 $ this ->objectMapper = new ObjectMapper ($ schemaManager );
7474 }
7575
76- public function send (int $ agentId , AgentInput $ message , ContextInterface $ context ): AgentOutput
76+ public function send (int $ agentId , int $ parentId , AgentInput $ message , ContextInterface $ context ): AgentOutput
7777 {
7878 $ row = $ this ->agentTable ->findOneByTenantAndId ($ context ->getTenantId (), $ context ->getUser ()->getCategoryId (), $ agentId );
7979 if (!$ row instanceof Table \Generated \AgentRow) {
@@ -93,11 +93,13 @@ public function send(int $agentId, AgentInput $message, ContextInterface $contex
9393 $ messages = new MessageBag ();
9494 $ messages ->add (Message::forSystem ($ row ->getIntroduction ()));
9595
96- $ messages = $ this ->loadPreviousMessages ($ agentId , $ context ->getUser ()->getId (), $ messages );
96+ if ($ parentId > 0 ) {
97+ $ messages = $ this ->loadPreviousMessages ($ agentId , $ context ->getUser ()->getId (), $ parentId , $ messages );
98+ }
9799
98100 $ userMessages = $ this ->messageUnserializer ->unserialize ($ input );
99101
100- $ this ->persistUserMessages ($ agentId , $ context ->getUser ()->getId (), $ userMessages );
102+ $ parentId = $ this ->persistUserMessages ($ agentId , $ context ->getUser ()->getId (), $ parentId , $ userMessages );
101103
102104 $ messages = $ messages ->merge ($ userMessages );
103105
@@ -126,7 +128,7 @@ public function send(int $agentId, AgentInput $message, ContextInterface $contex
126128 $ output = $ this ->resultSerializer ->serialize ($ result );
127129 }
128130
129- $ messageRow = $ this ->messageTable ->addAssistantMessage ($ row ->getId (), $ context ->getUser ()->getId (), $ output );
131+ $ messageRow = $ this ->messageTable ->addAssistantMessage ($ row ->getId (), $ context ->getUser ()->getId (), $ parentId , $ output );
130132
131133 $ this ->agentTable ->commit ();
132134
@@ -145,11 +147,16 @@ public function send(int $agentId, AgentInput $message, ContextInterface $contex
145147 }
146148 }
147149
148- private function loadPreviousMessages (int $ agentId , int $ userId , MessageBag $ messages ): MessageBag
150+ private function loadPreviousMessages (int $ agentId , int $ userId , int $ parentId , MessageBag $ messages ): MessageBag
149151 {
152+ $ idCondition = Condition::withOr ();
153+ $ idCondition ->equals (Table \Generated \AgentMessageColumn::ID , $ parentId );
154+ $ idCondition ->equals (Table \Generated \AgentMessageColumn::PARENT_ID , $ parentId );
155+
150156 $ condition = Condition::withAnd ();
151157 $ condition ->equals (Table \Generated \AgentMessageColumn::AGENT_ID , $ agentId );
152158 $ condition ->equals (Table \Generated \AgentMessageColumn::USER_ID , $ userId );
159+ $ condition ->add ($ idCondition );
153160
154161 $ count = $ this ->messageTable ->getCount ($ condition );
155162 $ startIndex = max (0 , $ count - self ::CONTEXT_MESSAGES_LENGTH );
@@ -178,13 +185,19 @@ private function loadPreviousMessages(int $agentId, int $userId, MessageBag $mes
178185 return $ messages ;
179186 }
180187
181- private function persistUserMessages (int $ agentId , int $ userId , MessageBag $ userMessages ): void
188+ private function persistUserMessages (int $ agentId , int $ userId , int $ parentId , MessageBag $ userMessages ): int
182189 {
183190 foreach ($ userMessages as $ userMessage ) {
184191 foreach ($ this ->messageSerializer ->serialize ($ userMessage ) as $ content ) {
185- $ this ->messageTable ->addUserMessage ($ agentId , $ userId , $ content );
192+ $ message = $ this ->messageTable ->addUserMessage ($ agentId , $ userId , $ parentId , $ content );
193+
194+ if ($ parentId === 0 ) {
195+ $ parentId = $ message ->getId ();
196+ }
186197 }
187198 }
199+
200+ return $ parentId ;
188201 }
189202
190203 private function getResponseSchema (Table \Generated \AgentRow $ row ): ?array
0 commit comments