Skip to content

Commit 0f1663d

Browse files
sunyuhan1998Willam2004
authored andcommitted
spring-projectsGH-4289: Optimized JdbcChatMemoryRepositoryDialect#from
* Rewrote the method to use `org.springframework.jdbc.support.JdbcUtils#extractDatabaseMetaData` for extracting database metadata from the `dataSource`, and obtain the database vendor name from the JDBC driver, improving accuracy and robustness. * Enhanced exception handling: instead of silently ignoring exceptions, it now explicitly reports encountered issues, helping users identify problems and select the appropriate dialect more easily; Fixes spring-projects#4289 Signed-off-by: Sun Yuhan <[email protected]> Signed-off-by: 家娃 <[email protected]>
1 parent 3db19a0 commit 0f1663d

File tree

1 file changed

+27
-23
lines changed

1 file changed

+27
-23
lines changed

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryDialect.java

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,22 @@
1616

1717
package org.springframework.ai.chat.memory.repository.jdbc;
1818

19-
import java.sql.Connection;
19+
import java.sql.DatabaseMetaData;
2020

2121
import javax.sql.DataSource;
2222

23+
import org.slf4j.Logger;
24+
import org.slf4j.LoggerFactory;
25+
26+
import org.springframework.jdbc.support.JdbcUtils;
27+
2328
/**
2429
* Abstraction for database-specific SQL for chat memory repository.
2530
*/
2631
public interface JdbcChatMemoryRepositoryDialect {
2732

33+
Logger logger = LoggerFactory.getLogger(JdbcChatMemoryRepositoryDialect.class);
34+
2835
/**
2936
* Returns the SQL to fetch messages for a conversation, ordered by timestamp, with
3037
* limit.
@@ -51,32 +58,29 @@ public interface JdbcChatMemoryRepositoryDialect {
5158
*/
5259

5360
/**
54-
* Detects the dialect from the DataSource or JDBC URL.
61+
* Detects the dialect from the DataSource.
5562
*/
5663
static JdbcChatMemoryRepositoryDialect from(DataSource dataSource) {
57-
// Simple detection (could be improved)
58-
try (Connection connection = dataSource.getConnection()) {
59-
String url = connection.getMetaData().getURL().toLowerCase();
60-
if (url.contains("postgresql")) {
61-
return new PostgresChatMemoryRepositoryDialect();
62-
}
63-
if (url.contains("mysql")) {
64-
return new MysqlChatMemoryRepositoryDialect();
65-
}
66-
if (url.contains("mariadb")) {
67-
return new MysqlChatMemoryRepositoryDialect();
68-
}
69-
if (url.contains("sqlserver")) {
70-
return new SqlServerChatMemoryRepositoryDialect();
71-
}
72-
if (url.contains("hsqldb")) {
73-
return new HsqldbChatMemoryRepositoryDialect();
74-
}
75-
// Add more as needed
64+
String productName = null;
65+
try {
66+
productName = JdbcUtils.extractDatabaseMetaData(dataSource, DatabaseMetaData::getDatabaseProductName);
67+
}
68+
catch (Exception e) {
69+
logger.warn("Due to failure in establishing JDBC connection or parsing metadata, the JDBC database vendor "
70+
+ "could not be determined", e);
7671
}
77-
catch (Exception ignored) {
72+
if (productName == null || productName.trim().isEmpty()) {
73+
logger.warn("Database product name is null or empty, defaulting to Postgres dialect.");
74+
return new PostgresChatMemoryRepositoryDialect();
7875
}
79-
return new PostgresChatMemoryRepositoryDialect(); // default
76+
return switch (productName) {
77+
case "PostgreSQL" -> new PostgresChatMemoryRepositoryDialect();
78+
case "MySQL", "MariaDB" -> new MysqlChatMemoryRepositoryDialect();
79+
case "Microsoft SQL Server" -> new SqlServerChatMemoryRepositoryDialect();
80+
case "HSQL Database Engine" -> new HsqldbChatMemoryRepositoryDialect();
81+
default -> // Add more as needed
82+
new PostgresChatMemoryRepositoryDialect();
83+
};
8084
}
8185

8286
}

0 commit comments

Comments
 (0)