1818
1919import java .time .Instant ;
2020import java .util .ArrayList ;
21- import java .util .Collections ;
2221import java .util .List ;
23- import java .util .concurrent . atomic . AtomicLong ;
22+ import java .util .Map ;
2423
2524import com .datastax .oss .driver .api .core .cql .BoundStatement ;
2625import com .datastax .oss .driver .api .core .cql .BoundStatementBuilder ;
2726import com .datastax .oss .driver .api .core .cql .PreparedStatement ;
2827import com .datastax .oss .driver .api .core .cql .Row ;
28+ import com .datastax .oss .driver .api .core .data .UdtValue ;
2929import com .datastax .oss .driver .api .querybuilder .QueryBuilder ;
30- import com .datastax .oss .driver .api .querybuilder .delete .Delete ;
31- import com .datastax .oss .driver .api .querybuilder .delete .DeleteSelection ;
3230import com .datastax .oss .driver .api .querybuilder .insert .InsertInto ;
3331import com .datastax .oss .driver .api .querybuilder .insert .RegularInsert ;
3432import com .datastax .oss .driver .api .querybuilder .select .Select ;
3533import com .datastax .oss .driver .shaded .guava .common .base .Preconditions ;
34+
3635import org .springframework .ai .chat .memory .ChatMemoryRepository ;
3736import org .springframework .ai .chat .messages .AssistantMessage ;
3837import org .springframework .ai .chat .messages .Message ;
38+ import org .springframework .ai .chat .messages .MessageType ;
39+ import org .springframework .ai .chat .messages .SystemMessage ;
40+ import org .springframework .ai .chat .messages .ToolResponseMessage ;
3941import org .springframework .ai .chat .messages .UserMessage ;
4042import org .springframework .util .Assert ;
4143
@@ -54,23 +56,17 @@ public class CassandraChatMemoryRepository implements ChatMemoryRepository {
5456
5557 private final PreparedStatement allStmt ;
5658
57- private final PreparedStatement addUserStmt ;
58-
59- private final PreparedStatement addAssistantStmt ;
59+ private final PreparedStatement addStmt ;
6060
6161 private final PreparedStatement getStmt ;
6262
63- private final PreparedStatement deleteStmt ;
64-
6563 private CassandraChatMemoryRepository (CassandraChatMemoryRepositoryConfig conf ) {
6664 Assert .notNull (conf , "conf cannot be null" );
6765 this .conf = conf ;
6866 this .conf .ensureSchemaExists ();
6967 this .allStmt = prepareAllStatement ();
70- this .addUserStmt = prepareAddStmt (this .conf .userColumn );
71- this .addAssistantStmt = prepareAddStmt (this .conf .assistantColumn );
68+ this .addStmt = prepareAddStmt ();
7269 this .getStmt = prepareGetStatement ();
73- this .deleteStmt = prepareDeleteStmt ();
7470 }
7571
7672 public static CassandraChatMemoryRepository create (CassandraChatMemoryRepositoryConfig conf ) {
@@ -97,6 +93,10 @@ public List<String> findConversationIds() {
9793
9894 @ Override
9995 public List <Message > findByConversationId (String conversationId ) {
96+ return findByConversationIdWithLimit (conversationId , 1 );
97+ }
98+
99+ List <Message > findByConversationIdWithLimit (String conversationId , int limit ) {
100100 Assert .hasText (conversationId , "conversationId cannot be null or empty" );
101101
102102 List <Object > primaryKeys = this .conf .primaryKeyTranslator .apply (conversationId );
@@ -106,19 +106,14 @@ public List<Message> findByConversationId(String conversationId) {
106106 CassandraChatMemoryRepositoryConfig .SchemaColumn keyColumn = this .conf .getPrimaryKeyColumn (k );
107107 builder = builder .set (keyColumn .name (), primaryKeys .get (k ), keyColumn .javaType ());
108108 }
109+ builder = builder .setInt ("legacy_limit" , limit );
109110
110111 List <Message > messages = new ArrayList <>();
111112 for (Row r : this .conf .session .execute (builder .build ())) {
112- String assistant = r .getString (this .conf .assistantColumn );
113- String user = r .getString (this .conf .userColumn );
114- if (null != assistant ) {
115- messages .add (new AssistantMessage (assistant ));
116- }
117- if (null != user ) {
118- messages .add (new UserMessage (user ));
113+ for (UdtValue udt : r .getList (this .conf .messagesColumn , UdtValue .class )) {
114+ messages .add (getMessage (udt ));
119115 }
120116 }
121- Collections .reverse (messages );
122117 return messages ;
123118 }
124119
@@ -128,58 +123,49 @@ public void saveAll(String conversationId, List<Message> messages) {
128123 Assert .notNull (messages , "messages cannot be null" );
129124 Assert .noNullElements (messages , "messages cannot contain null elements" );
130125
131- final AtomicLong instantSeq = new AtomicLong (Instant .now ().toEpochMilli ());
132- messages .forEach (msg -> {
133- if (msg .getMetadata ().containsKey (CONVERSATION_TS )) {
134- msg .getMetadata ().put (CONVERSATION_TS , Instant .ofEpochMilli (instantSeq .getAndIncrement ()));
135- }
136- save (conversationId , msg );
137- });
138- }
139-
140- void save (String conversationId , Message msg ) {
141-
142- Preconditions .checkArgument (
143- !msg .getMetadata ().containsKey (CONVERSATION_TS )
144- || msg .getMetadata ().get (CONVERSATION_TS ) instanceof Instant ,
145- "messages only accept metadata '%s' entries of type Instant" , CONVERSATION_TS );
146-
147- msg .getMetadata ().putIfAbsent (CONVERSATION_TS , Instant .now ());
148-
149- PreparedStatement stmt = getStatement (msg );
150-
126+ Instant instant = Instant .now ();
151127 List <Object > primaryKeys = this .conf .primaryKeyTranslator .apply (conversationId );
152- BoundStatementBuilder builder = stmt .boundStatementBuilder ();
128+ BoundStatementBuilder builder = addStmt .boundStatementBuilder ();
153129
154130 for (int k = 0 ; k < primaryKeys .size (); ++k ) {
155131 CassandraChatMemoryRepositoryConfig .SchemaColumn keyColumn = this .conf .getPrimaryKeyColumn (k );
156132 builder = builder .set (keyColumn .name (), primaryKeys .get (k ), keyColumn .javaType ());
157133 }
158134
159- Instant instant = (Instant ) msg .getMetadata ().get (CONVERSATION_TS );
135+ List <UdtValue > msgs = new ArrayList <>();
136+ for (Message msg : messages ) {
137+
138+ Preconditions .checkArgument (
139+ !msg .getMetadata ().containsKey (CONVERSATION_TS )
140+ || msg .getMetadata ().get (CONVERSATION_TS ) instanceof Instant ,
141+ "messages only accept metadata '%s' entries of type Instant" , CONVERSATION_TS );
160142
143+ msg .getMetadata ().putIfAbsent (CONVERSATION_TS , instant );
144+
145+ UdtValue udt = this .conf .session .getMetadata ()
146+ .getKeyspace (this .conf .schema .keyspace ())
147+ .get ()
148+ .getUserDefinedType (this .conf .messageUDT )
149+ .get ()
150+ .newValue ()
151+ .setInstant (this .conf .messageUdtTimestampColumn , (Instant ) msg .getMetadata ().get (CONVERSATION_TS ))
152+ .setString (this .conf .messageUdtTypeColumn , msg .getMessageType ().name ())
153+ .setString (this .conf .messageUdtContentColumn , msg .getText ());
154+
155+ msgs .add (udt );
156+ }
161157 builder = builder .setInstant (CassandraChatMemoryRepositoryConfig .DEFAULT_EXCHANGE_ID_NAME , instant )
162- .setString ( "message " , msg . getText () );
158+ .setList ( "msgs " , msgs , UdtValue . class );
163159
164160 this .conf .session .execute (builder .build ());
165161 }
166162
167163 @ Override
168164 public void deleteByConversationId (String conversationId ) {
169- Assert .hasText (conversationId , "conversationId cannot be null or empty" );
170-
171- List <Object > primaryKeys = this .conf .primaryKeyTranslator .apply (conversationId );
172- BoundStatementBuilder builder = this .deleteStmt .boundStatementBuilder ();
173-
174- for (int k = 0 ; k < primaryKeys .size (); ++k ) {
175- CassandraChatMemoryRepositoryConfig .SchemaColumn keyColumn = this .conf .getPrimaryKeyColumn (k );
176- builder = builder .set (keyColumn .name (), primaryKeys .get (k ), keyColumn .javaType ());
177- }
178-
179- this .conf .session .execute (builder .build ());
165+ saveAll (conversationId , List .of ());
180166 }
181167
182- private PreparedStatement prepareAddStmt (String column ) {
168+ private PreparedStatement prepareAddStmt () {
183169 RegularInsert stmt = null ;
184170 InsertInto stmtStart = QueryBuilder .insertInto (this .conf .schema .keyspace (), this .conf .schema .table ());
185171 for (var c : this .conf .schema .partitionKeys ()) {
@@ -188,7 +174,7 @@ private PreparedStatement prepareAddStmt(String column) {
188174 for (var c : this .conf .schema .clusteringKeys ()) {
189175 stmt = stmt .value (c .name (), QueryBuilder .bindMarker (c .name ()));
190176 }
191- stmt = stmt .value (column , QueryBuilder .bindMarker ("message " ));
177+ stmt = stmt .value (this . conf . messagesColumn , QueryBuilder .bindMarker ("msgs " ));
192178 return this .conf .session .prepare (stmt .build ());
193179 }
194180
@@ -214,28 +200,27 @@ private PreparedStatement prepareGetStatement() {
214200 String columnName = this .conf .schema .clusteringKeys ().get (i ).name ();
215201 stmt = stmt .whereColumn (columnName ).isEqualTo (QueryBuilder .bindMarker (columnName ));
216202 }
203+ stmt = stmt .limit (QueryBuilder .bindMarker ("legacy_limit" ));
217204 return this .conf .session .prepare (stmt .build ());
218205 }
219206
220- private PreparedStatement prepareDeleteStmt () {
221- Delete stmt = null ;
222- DeleteSelection stmtStart = QueryBuilder .deleteFrom (this .conf .schema .keyspace (), this .conf .schema .table ());
223- for (var c : this .conf .schema .partitionKeys ()) {
224- stmt = (null != stmt ? stmt : stmtStart ).whereColumn (c .name ()).isEqualTo (QueryBuilder .bindMarker (c .name ()));
225- }
226- for (int i = 0 ; i + 1 < this .conf .schema .clusteringKeys ().size (); ++i ) {
227- String columnName = this .conf .schema .clusteringKeys ().get (i ).name ();
228- stmt = stmt .whereColumn (columnName ).isEqualTo (QueryBuilder .bindMarker (columnName ));
207+ private Message getMessage (UdtValue udt ) {
208+ String content = udt .getString (this .conf .messageUdtContentColumn );
209+ Map <String , Object > props = Map .of (CONVERSATION_TS , udt .getInstant (this .conf .messageUdtTimestampColumn ));
210+ switch (MessageType .valueOf (udt .getString (this .conf .messageUdtTypeColumn ))) {
211+ case ASSISTANT :
212+ return new AssistantMessage (content , props );
213+ case USER :
214+ return UserMessage .builder ().text (content ).metadata (props ).build ();
215+ case SYSTEM :
216+ return SystemMessage .builder ().text (content ).metadata (props ).build ();
217+ case TOOL :
218+ // todo – persist ToolResponse somehow
219+ return new ToolResponseMessage (List .of (), props );
220+ default :
221+ throw new IllegalStateException (
222+ String .format ("unknown message type %s" , udt .getString (this .conf .messageUdtTypeColumn )));
229223 }
230- return this .conf .session .prepare (stmt .build ());
231- }
232-
233- private PreparedStatement getStatement (Message msg ) {
234- return switch (msg .getMessageType ()) {
235- case USER -> this .addUserStmt ;
236- case ASSISTANT -> this .addAssistantStmt ;
237- default -> throw new IllegalArgumentException ("Cant add type " + msg );
238- };
239224 }
240225
241226}
0 commit comments