|
22 | 22 |
|
23 | 23 | import org.springframework.boot.jdbc.init.DataSourceScriptDatabaseInitializer; |
24 | 24 | import org.springframework.boot.jdbc.init.PlatformPlaceholderDatabaseDriverResolver; |
25 | | -import org.springframework.boot.sql.init.DatabaseInitializationMode; |
26 | 25 | import org.springframework.boot.sql.init.DatabaseInitializationSettings; |
27 | 26 | import org.springframework.util.StringUtils; |
28 | 27 |
|
29 | 28 | /** |
30 | 29 | * Performs database initialization for the JDBC Chat Memory Repository. |
31 | 30 | * |
32 | 31 | * @author Mark Pollack |
| 32 | + * @author Yanming Zhou |
33 | 33 | * @since 1.0.0 |
34 | 34 | */ |
35 | 35 | class JdbcChatMemoryRepositorySchemaInitializer extends DataSourceScriptDatabaseInitializer { |
36 | 36 |
|
37 | | - private static final String DEFAULT_SCHEMA_LOCATION = "classpath:org/springframework/ai/chat/memory/jdbc/schema-@@platform@@.sql"; |
38 | | - |
39 | 37 | JdbcChatMemoryRepositorySchemaInitializer(DataSource dataSource, JdbcChatMemoryRepositoryProperties properties) { |
40 | 38 | super(dataSource, getSettings(dataSource, properties)); |
41 | 39 | } |
42 | 40 |
|
43 | 41 | static DatabaseInitializationSettings getSettings(DataSource dataSource, |
44 | 42 | JdbcChatMemoryRepositoryProperties properties) { |
45 | 43 | var settings = new DatabaseInitializationSettings(); |
46 | | - |
47 | | - // Determine schema locations |
48 | | - String schemaProp = properties.getSchema(); |
49 | | - List<String> schemaLocations; |
50 | | - PlatformPlaceholderDatabaseDriverResolver resolver = new PlatformPlaceholderDatabaseDriverResolver(); |
51 | | - try { |
52 | | - String url = dataSource.getConnection().getMetaData().getURL().toLowerCase(); |
53 | | - if (url.contains("hsqldb")) { |
54 | | - schemaLocations = List.of("classpath:org/springframework/ai/chat/memory/jdbc/schema-hsqldb.sql"); |
55 | | - } |
56 | | - else if (StringUtils.hasText(schemaProp)) { |
57 | | - schemaLocations = resolver.resolveAll(dataSource, schemaProp); |
58 | | - } |
59 | | - else { |
60 | | - schemaLocations = resolver.resolveAll(dataSource, DEFAULT_SCHEMA_LOCATION); |
61 | | - } |
62 | | - } |
63 | | - catch (Exception e) { |
64 | | - // fallback to default |
65 | | - if (StringUtils.hasText(schemaProp)) { |
66 | | - schemaLocations = resolver.resolveAll(dataSource, schemaProp); |
67 | | - } |
68 | | - else { |
69 | | - schemaLocations = resolver.resolveAll(dataSource, DEFAULT_SCHEMA_LOCATION); |
70 | | - } |
71 | | - } |
72 | | - settings.setSchemaLocations(schemaLocations); |
73 | | - |
74 | | - // Determine initialization mode |
75 | | - JdbcChatMemoryRepositoryProperties.DatabaseInitializationMode init = properties.getInitializeSchema(); |
76 | | - DatabaseInitializationMode mode; |
77 | | - if (JdbcChatMemoryRepositoryProperties.DatabaseInitializationMode.ALWAYS.equals(init)) { |
78 | | - mode = DatabaseInitializationMode.ALWAYS; |
79 | | - } |
80 | | - else if (JdbcChatMemoryRepositoryProperties.DatabaseInitializationMode.NEVER.equals(init)) { |
81 | | - mode = DatabaseInitializationMode.NEVER; |
82 | | - } |
83 | | - else { |
84 | | - // embedded or default |
85 | | - mode = DatabaseInitializationMode.EMBEDDED; |
86 | | - } |
87 | | - settings.setMode(mode); |
| 44 | + settings.setSchemaLocations(resolveSchemaLocations(dataSource, properties)); |
| 45 | + settings.setMode(properties.getInitializeSchema()); |
88 | 46 | settings.setContinueOnError(true); |
89 | 47 | return settings; |
90 | 48 | } |
91 | 49 |
|
| 50 | + private static List<String> resolveSchemaLocations(DataSource dataSource, |
| 51 | + JdbcChatMemoryRepositoryProperties properties) { |
| 52 | + PlatformPlaceholderDatabaseDriverResolver platformResolver = new PlatformPlaceholderDatabaseDriverResolver(); |
| 53 | + if (StringUtils.hasText(properties.getPlatform())) { |
| 54 | + return platformResolver.resolveAll(properties.getPlatform(), properties.getSchema()); |
| 55 | + } |
| 56 | + return platformResolver.resolveAll(dataSource, properties.getSchema()); |
| 57 | + } |
| 58 | + |
92 | 59 | } |
0 commit comments