Skip to content

Commit af61678

Browse files
committed
feat: JDBC implementation of ChatMemory
Signed-off-by: leijendary <[email protected]> feat: JDBC implementation of ChatMemory Signed-off-by: leijendary <[email protected]> feat: JDBC implementation of ChatMemory Signed-off-by: leijendary <[email protected]>
1 parent 822576b commit af61678

File tree

26 files changed

+1073
-6
lines changed

26 files changed

+1073
-6
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,9 @@ One way to run integration tests on part of the code is to first do a quick comp
101101
```shell
102102
./mvnw clean install -DskipTests -Dmaven.javadoc.skip=true
103103
```
104-
Then run the integration test for a specifi module using the `-pl` option
104+
Then run the integration test for a specific module using the `-pl` option
105105
```shell
106-
./mvnw verify -Pintegration-tests -pl spring-ai-spring-boot-autoconfigure
106+
./mvnw verify -Pintegration-tests -pl spring-ai-spring-boot-autoconfigure
107107
```
108108

109109
### Documentation
@@ -134,4 +134,4 @@ To build with checkstyles enabled.
134134
Checkstyles are currently disabled, but you can enable them by doing the following:
135135
```shell
136136
./mvnw clean package -DskipTests -Ddisable.checks=false
137-
```
137+
```
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[Chat Memory Documentation](https://docs.spring.io/spring-ai/reference/api/chatclient.html#_chat_memory)
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<!--
3+
~ Copyright 2023-2024 the original author or authors.
4+
~
5+
~ Licensed under the Apache License, Version 2.0 (the "License");
6+
~ you may not use this file except in compliance with the License.
7+
~ You may obtain a copy of the License at
8+
~
9+
~ https://www.apache.org/licenses/LICENSE-2.0
10+
~
11+
~ Unless required by applicable law or agreed to in writing, software
12+
~ distributed under the License is distributed on an "AS IS" BASIS,
13+
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
~ See the License for the specific language governing permissions and
15+
~ limitations under the License.
16+
-->
17+
18+
<project xmlns="http://maven.apache.org/POM/4.0.0"
19+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
20+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
21+
<modelVersion>4.0.0</modelVersion>
22+
<parent>
23+
<groupId>org.springframework.ai</groupId>
24+
<artifactId>spring-ai</artifactId>
25+
<version>1.0.0-SNAPSHOT</version>
26+
<relativePath>../../pom.xml</relativePath>
27+
</parent>
28+
<artifactId>spring-ai-chat-memory-jdbc</artifactId>
29+
<packaging>jar</packaging>
30+
<name>Spring AI Chat Memory JDBC</name>
31+
<description>Spring AI Chat Memory implementation with JDBC</description>
32+
<url>https://github.com/spring-projects/spring-ai</url>
33+
34+
<scm>
35+
<url>https://github.com/spring-projects/spring-ai</url>
36+
<connection>git://github.com/spring-projects/spring-ai.git</connection>
37+
<developerConnection>[email protected]:spring-projects/spring-ai.git</developerConnection>
38+
</scm>
39+
40+
<properties>
41+
<maven.compiler.source>17</maven.compiler.source>
42+
<maven.compiler.target>17</maven.compiler.target>
43+
</properties>
44+
45+
<dependencies>
46+
<dependency>
47+
<groupId>org.springframework.ai</groupId>
48+
<artifactId>spring-ai-core</artifactId>
49+
<version>${project.parent.version}</version>
50+
</dependency>
51+
52+
<dependency>
53+
<groupId>com.zaxxer</groupId>
54+
<artifactId>HikariCP</artifactId>
55+
</dependency>
56+
57+
<dependency>
58+
<groupId>org.springframework</groupId>
59+
<artifactId>spring-jdbc</artifactId>
60+
</dependency>
61+
62+
<dependency>
63+
<groupId>org.postgresql</groupId>
64+
<artifactId>postgresql</artifactId>
65+
<version>${postgresql.version}</version>
66+
<optional>true</optional>
67+
</dependency>
68+
69+
<dependency>
70+
<groupId>org.mariadb.jdbc</groupId>
71+
<artifactId>mariadb-java-client</artifactId>
72+
<version>${mariadb.version}</version>
73+
<optional>true</optional>
74+
</dependency>
75+
76+
<!-- TESTING -->
77+
<dependency>
78+
<groupId>org.springframework.boot</groupId>
79+
<artifactId>spring-boot-starter-test</artifactId>
80+
<scope>test</scope>
81+
</dependency>
82+
83+
<dependency>
84+
<groupId>org.testcontainers</groupId>
85+
<artifactId>testcontainers</artifactId>
86+
<scope>test</scope>
87+
</dependency>
88+
89+
<dependency>
90+
<groupId>org.testcontainers</groupId>
91+
<artifactId>postgresql</artifactId>
92+
<scope>test</scope>
93+
</dependency>
94+
95+
<dependency>
96+
<groupId>org.testcontainers</groupId>
97+
<artifactId>mariadb</artifactId>
98+
<scope>test</scope>
99+
</dependency>
100+
101+
<dependency>
102+
<groupId>org.testcontainers</groupId>
103+
<artifactId>junit-jupiter</artifactId>
104+
<scope>test</scope>
105+
</dependency>
106+
</dependencies>
107+
</project>
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/*
2+
* Copyright 2024-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.memory.jdbc;
18+
19+
import java.sql.PreparedStatement;
20+
import java.sql.ResultSet;
21+
import java.sql.SQLException;
22+
import java.util.List;
23+
24+
import org.springframework.ai.chat.memory.ChatMemory;
25+
import org.springframework.ai.chat.messages.AssistantMessage;
26+
import org.springframework.ai.chat.messages.Message;
27+
import org.springframework.ai.chat.messages.MessageType;
28+
import org.springframework.ai.chat.messages.UserMessage;
29+
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
30+
import org.springframework.jdbc.core.JdbcTemplate;
31+
import org.springframework.jdbc.core.RowMapper;
32+
33+
/**
34+
* An implementation of {@link ChatMemory} for JDBC. Creating an instance of
35+
* JdbcChatMemory example:
36+
* <code>JdbcChatMemory.create(JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build());</code>
37+
*
38+
* @author Jonathan Leijendekker
39+
* @since 1.0.0
40+
*/
41+
public class JdbcChatMemory implements ChatMemory {
42+
43+
private static final String QUERY_ADD = """
44+
INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)""";
45+
46+
private static final String QUERY_GET = """
47+
SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?""";
48+
49+
private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?";
50+
51+
private final JdbcTemplate jdbcTemplate;
52+
53+
public JdbcChatMemory(JdbcChatMemoryConfig config) {
54+
this.jdbcTemplate = config.getJdbcTemplate();
55+
}
56+
57+
public static JdbcChatMemory create(JdbcChatMemoryConfig config) {
58+
return new JdbcChatMemory(config);
59+
}
60+
61+
@Override
62+
public void add(String conversationId, List<Message> messages) {
63+
this.jdbcTemplate.batchUpdate(QUERY_ADD, new AddBatchPreparedStatement(conversationId, messages));
64+
}
65+
66+
@Override
67+
public List<Message> get(String conversationId, int lastN) {
68+
return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN);
69+
}
70+
71+
@Override
72+
public void clear(String conversationId) {
73+
this.jdbcTemplate.update(QUERY_CLEAR, conversationId);
74+
}
75+
76+
private record AddBatchPreparedStatement(String conversationId,
77+
List<Message> messages) implements BatchPreparedStatementSetter {
78+
@Override
79+
public void setValues(PreparedStatement ps, int i) throws SQLException {
80+
var message = this.messages.get(i);
81+
82+
ps.setString(1, this.conversationId);
83+
ps.setString(2, message.getText());
84+
ps.setString(3, message.getMessageType().name());
85+
}
86+
87+
@Override
88+
public int getBatchSize() {
89+
return this.messages.size();
90+
}
91+
}
92+
93+
private static class MessageRowMapper implements RowMapper<Message> {
94+
95+
@Override
96+
public Message mapRow(ResultSet rs, int i) throws SQLException {
97+
var content = rs.getString(1);
98+
var type = MessageType.valueOf(rs.getString(2));
99+
100+
return switch (type) {
101+
case USER -> new UserMessage(content);
102+
case ASSISTANT -> new AssistantMessage(content);
103+
default -> null;
104+
};
105+
}
106+
107+
}
108+
109+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright 2024-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.memory.jdbc;
18+
19+
import org.springframework.jdbc.core.JdbcTemplate;
20+
import org.springframework.util.Assert;
21+
22+
/**
23+
* Configuration for {@link JdbcChatMemory}.
24+
*
25+
* @author Jonathan Leijendekker
26+
* @since 1.0.0
27+
*/
28+
public final class JdbcChatMemoryConfig {
29+
30+
private final JdbcTemplate jdbcTemplate;
31+
32+
private JdbcChatMemoryConfig(Builder builder) {
33+
this.jdbcTemplate = builder.jdbcTemplate;
34+
}
35+
36+
public static Builder builder() {
37+
return new Builder();
38+
}
39+
40+
JdbcTemplate getJdbcTemplate() {
41+
return this.jdbcTemplate;
42+
}
43+
44+
public static final class Builder {
45+
46+
private JdbcTemplate jdbcTemplate;
47+
48+
private Builder() {
49+
}
50+
51+
public Builder jdbcTemplate(JdbcTemplate jdbcTemplate) {
52+
Assert.notNull(jdbcTemplate, "jdbc template must not be null");
53+
54+
this.jdbcTemplate = jdbcTemplate;
55+
return this;
56+
}
57+
58+
public JdbcChatMemoryConfig build() {
59+
Assert.notNull(this.jdbcTemplate, "jdbc template must not be null");
60+
61+
return new JdbcChatMemoryConfig(this);
62+
}
63+
64+
}
65+
66+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package org.springframework.ai.chat.memory.jdbc.aot.hint;
2+
3+
import javax.sql.DataSource;
4+
5+
import org.springframework.aot.hint.MemberCategory;
6+
import org.springframework.aot.hint.RuntimeHints;
7+
import org.springframework.aot.hint.RuntimeHintsRegistrar;
8+
9+
/**
10+
* A {@link RuntimeHintsRegistrar} for JDBC Chat Memory hints
11+
*
12+
* @author Jonathan Leijendekker
13+
*/
14+
class JdbcChatMemoryRuntimeHints implements RuntimeHintsRegistrar {
15+
16+
@Override
17+
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
18+
hints.reflection()
19+
.registerType(DataSource.class, (hint) -> hint.withMembers(MemberCategory.INVOKE_DECLARED_METHODS));
20+
21+
hints.resources()
22+
.registerPattern("org/springframework/ai/chat/memory/jdbc/schema-drop-mariadb.sql")
23+
.registerPattern("org/springframework/ai/chat/memory/jdbc/schema-drop-postgresql.sql")
24+
.registerPattern("org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql")
25+
.registerPattern("org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql");
26+
}
27+
28+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
org.springframework.aot.hint.RuntimeHintsRegistrar=\
2+
org.springframework.ai.chat.memory.jdbc.aot.hint.JdbcChatMemoryRuntimeHints
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
DROP TABLE IF EXISTS ai_chat_memory;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
DROP TABLE IF EXISTS ai_chat_memory;
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
CREATE TABLE IF NOT EXISTS ai_chat_memory (
2+
conversation_id VARCHAR(36) NOT NULL,
3+
content TEXT NOT NULL,
4+
type VARCHAR(10) NOT NULL,
5+
`timestamp` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
6+
CONSTRAINT type_check CHECK (type IN ('USER', 'ASSISTANT'))
7+
);
8+
9+
CREATE INDEX IF NOT EXISTS ai_chat_memory_conversation_id_timestamp_idx
10+
ON ai_chat_memory(conversation_id, `timestamp` DESC);

0 commit comments

Comments
 (0)