Skip to content

Commit f3f411a

Browse files
markpollackleijendary
authored andcommitted
Add createExtension option for PostgresML config
The PostgresMLEmbeddingModel autoconfiguration previously always executed "CREATE EXTENSION IF NOT EXISTS pgml" on startup. This could cause issues for users without superuser privileges or those who manage extensions through other means. Added a new configuration property 'createExtension' (default false) to make this behavior optional. Users can now explicitly enable extension creation when needed. Updated documentation to explain the new configuration option and its implications for deployment. Signed-off-by: leijendary <[email protected]>
1 parent cb83a7e commit f3f411a

File tree

5 files changed

+57
-25
lines changed

5 files changed

+57
-25
lines changed

models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModel.java

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,20 +55,27 @@ public class PostgresMlEmbeddingModel extends AbstractEmbeddingModel implements
5555

5656
private final JdbcTemplate jdbcTemplate;
5757

58+
private final boolean createExtension;
59+
5860
/**
5961
* a constructor
6062
* @param jdbcTemplate JdbcTemplate
6163
*/
6264
public PostgresMlEmbeddingModel(JdbcTemplate jdbcTemplate) {
63-
this(jdbcTemplate, PostgresMlEmbeddingOptions.builder().build());
65+
this(jdbcTemplate, PostgresMlEmbeddingOptions.builder().build(), false);
66+
}
67+
68+
public PostgresMlEmbeddingModel(JdbcTemplate jdbcTemplate, PostgresMlEmbeddingOptions options) {
69+
this(jdbcTemplate, options, false);
6470
}
6571

6672
/**
6773
* a PostgresMlEmbeddingModel constructor
6874
* @param jdbcTemplate JdbcTemplate to use to interact with the database.
6975
* @param options PostgresMlEmbeddingOptions to configure the client.
7076
*/
71-
public PostgresMlEmbeddingModel(JdbcTemplate jdbcTemplate, PostgresMlEmbeddingOptions options) {
77+
public PostgresMlEmbeddingModel(JdbcTemplate jdbcTemplate, PostgresMlEmbeddingOptions options,
78+
boolean createExtension) {
7279
Assert.notNull(jdbcTemplate, "jdbc template must not be null.");
7380
Assert.notNull(options, "options must not be null.");
7481
Assert.notNull(options.getTransformer(), "transformer must not be null.");
@@ -78,6 +85,7 @@ public PostgresMlEmbeddingModel(JdbcTemplate jdbcTemplate, PostgresMlEmbeddingOp
7885

7986
this.jdbcTemplate = jdbcTemplate;
8087
this.defaultOptions = options;
88+
this.createExtension = createExtension;
8189
}
8290

8391
/**
@@ -99,7 +107,7 @@ public PostgresMlEmbeddingModel(JdbcTemplate jdbcTemplate, String transformer) {
99107
*/
100108
@Deprecated(since = "0.8.0", forRemoval = true)
101109
public PostgresMlEmbeddingModel(JdbcTemplate jdbcTemplate, String transformer, VectorType vectorType) {
102-
this(jdbcTemplate, transformer, vectorType, Map.of(), MetadataMode.EMBED);
110+
this(jdbcTemplate, transformer, vectorType, Map.of(), MetadataMode.EMBED, false);
103111
}
104112

105113
/**
@@ -112,7 +120,7 @@ public PostgresMlEmbeddingModel(JdbcTemplate jdbcTemplate, String transformer, V
112120
*/
113121
@Deprecated(since = "0.8.0", forRemoval = true)
114122
public PostgresMlEmbeddingModel(JdbcTemplate jdbcTemplate, String transformer, VectorType vectorType,
115-
Map<String, Object> kwargs, MetadataMode metadataMode) {
123+
Map<String, Object> kwargs, MetadataMode metadataMode, boolean createExtension) {
116124
Assert.notNull(jdbcTemplate, "jdbc template must not be null.");
117125
Assert.notNull(transformer, "transformer must not be null.");
118126
Assert.notNull(vectorType, "vectorType must not be null.");
@@ -127,6 +135,7 @@ public PostgresMlEmbeddingModel(JdbcTemplate jdbcTemplate, String transformer, V
127135
.withMetadataMode(metadataMode)
128136
.withKwargs(ModelOptionsUtils.toJsonString(kwargs))
129137
.build();
138+
this.createExtension = createExtension;
130139
}
131140

132141
@SuppressWarnings("null")
@@ -202,6 +211,9 @@ PostgresMlEmbeddingOptions mergeOptions(EmbeddingOptions requestOptions) {
202211

203212
@Override
204213
public void afterPropertiesSet() {
214+
if (!this.createExtension) {
215+
return;
216+
}
205217
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS pgml");
206218
if (StringUtils.hasText(this.defaultOptions.getVectorType().extensionName)) {
207219
this.jdbcTemplate

models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModelIT.java

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,12 @@
1616

1717
package org.springframework.ai.postgresml;
1818

19-
import java.time.Duration;
20-
import java.time.temporal.ChronoUnit;
21-
import java.util.List;
22-
import java.util.Map;
23-
2419
import org.assertj.core.api.Assertions;
25-
import org.junit.jupiter.api.AfterEach;
20+
import org.junit.jupiter.api.BeforeEach;
2621
import org.junit.jupiter.api.Disabled;
2722
import org.junit.jupiter.api.Test;
2823
import org.junit.jupiter.params.ParameterizedTest;
2924
import org.junit.jupiter.params.provider.ValueSource;
30-
import org.testcontainers.containers.PostgreSQLContainer;
31-
import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy;
32-
import org.testcontainers.junit.jupiter.Container;
33-
import org.testcontainers.junit.jupiter.Testcontainers;
34-
import org.testcontainers.utility.DockerImageName;
35-
3625
import org.springframework.ai.document.Document;
3726
import org.springframework.ai.document.MetadataMode;
3827
import org.springframework.ai.embedding.EmbeddingOptions;
@@ -46,6 +35,16 @@
4635
import org.springframework.boot.test.autoconfigure.jdbc.JdbcTest;
4736
import org.springframework.boot.testcontainers.service.connection.ServiceConnection;
4837
import org.springframework.jdbc.core.JdbcTemplate;
38+
import org.testcontainers.containers.PostgreSQLContainer;
39+
import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy;
40+
import org.testcontainers.junit.jupiter.Container;
41+
import org.testcontainers.junit.jupiter.Testcontainers;
42+
import org.testcontainers.utility.DockerImageName;
43+
44+
import java.time.Duration;
45+
import java.time.temporal.ChronoUnit;
46+
import java.util.List;
47+
import java.util.Map;
4948

5049
import static org.assertj.core.api.Assertions.assertThat;
5150

@@ -73,14 +72,15 @@ class PostgresMlEmbeddingModelIT {
7372
@Autowired
7473
JdbcTemplate jdbcTemplate;
7574

76-
@AfterEach
75+
@BeforeEach
7776
void dropPgmlExtension() {
7877
this.jdbcTemplate.execute("DROP EXTENSION IF EXISTS pgml");
7978
}
8079

8180
@Test
8281
void embed() {
83-
PostgresMlEmbeddingModel embeddingModel = new PostgresMlEmbeddingModel(this.jdbcTemplate);
82+
PostgresMlEmbeddingModel embeddingModel = new PostgresMlEmbeddingModel(this.jdbcTemplate,
83+
PostgresMlEmbeddingOptions.builder().build(), true);
8484
embeddingModel.afterPropertiesSet();
8585

8686
float[] embed = embeddingModel.embed("Hello World!");
@@ -94,7 +94,8 @@ void embedWithPgVector() {
9494
PostgresMlEmbeddingOptions.builder()
9595
.withTransformer("distilbert-base-uncased")
9696
.withVectorType(PostgresMlEmbeddingModel.VectorType.PG_VECTOR)
97-
.build());
97+
.build(),
98+
true);
9899
embeddingModel.afterPropertiesSet();
99100

100101
float[] embed = embeddingModel.embed(new Document("Hello World!"));
@@ -105,7 +106,7 @@ void embedWithPgVector() {
105106
@Test
106107
void embedWithDifferentModel() {
107108
PostgresMlEmbeddingModel embeddingModel = new PostgresMlEmbeddingModel(this.jdbcTemplate,
108-
PostgresMlEmbeddingOptions.builder().withTransformer("intfloat/e5-small").build());
109+
PostgresMlEmbeddingOptions.builder().withTransformer("intfloat/e5-small").build(), true);
109110
embeddingModel.afterPropertiesSet();
110111

111112
float[] embed = embeddingModel.embed(new Document("Hello World!"));
@@ -121,7 +122,8 @@ void embedWithKwargs() {
121122
.withVectorType(PostgresMlEmbeddingModel.VectorType.PG_ARRAY)
122123
.withKwargs(Map.of("device", "cpu"))
123124
.withMetadataMode(MetadataMode.EMBED)
124-
.build());
125+
.build(),
126+
true);
125127
embeddingModel.afterPropertiesSet();
126128

127129
float[] embed = embeddingModel.embed(new Document("Hello World!"));
@@ -136,7 +138,8 @@ void embedForResponse(String vectorType) {
136138
PostgresMlEmbeddingOptions.builder()
137139
.withTransformer("distilbert-base-uncased")
138140
.withVectorType(VectorType.valueOf(vectorType))
139-
.build());
141+
.build(),
142+
true);
140143
embeddingModel.afterPropertiesSet();
141144

142145
EmbeddingResponse embeddingResponse = embeddingModel
@@ -174,7 +177,8 @@ void embedCallWithRequestOptionsOverride() {
174177
PostgresMlEmbeddingOptions.builder()
175178
.withTransformer("distilbert-base-uncased")
176179
.withVectorType(VectorType.PG_VECTOR)
177-
.build());
180+
.build(),
181+
true);
178182
embeddingModel.afterPropertiesSet();
179183

180184
var request1 = new EmbeddingRequest(List.of("Hello World!", "Spring AI!", "LLM!"), EmbeddingOptions.EMPTY);
@@ -244,7 +248,8 @@ void embedCallWithRequestOptionsOverride() {
244248

245249
@Test
246250
void dimensions() {
247-
PostgresMlEmbeddingModel embeddingModel = new PostgresMlEmbeddingModel(this.jdbcTemplate);
251+
PostgresMlEmbeddingModel embeddingModel = new PostgresMlEmbeddingModel(this.jdbcTemplate,
252+
PostgresMlEmbeddingOptions.builder().build(), true);
248253
embeddingModel.afterPropertiesSet();
249254
Assertions.assertThat(embeddingModel.dimensions()).isEqualTo(768);
250255
// cached

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/postgresml-embeddings.adoc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ The prefix `spring.ai.postgresml.embedding` is property prefix that configures t
5050
|====
5151
| Property | Description | Default
5252
| spring.ai.postgresml.embedding.enabled | Enable PostgresML embedding model. | true
53+
| spring.ai.postgresml.embedding.create-extension | Execute the SQL 'CREATE EXTENSION IF NOT EXISTS pgml' to enable the extesnion | false
5354
| spring.ai.postgresml.embedding.options.transformer | The Hugging Face transformer model to use for the embedding. | distilbert-base-uncased
5455
| spring.ai.postgresml.embedding.options.kwargs | Additional transformer specific options. | empty map
5556
| spring.ai.postgresml.embedding.options.vectorType | PostgresML vector type to use for the embedding. Two options are supported: `PG_ARRAY` and `PG_VECTOR`. | PG_ARRAY

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfiguration.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ public class PostgresMlAutoConfiguration {
4444
public PostgresMlEmbeddingModel postgresMlEmbeddingModel(JdbcTemplate jdbcTemplate,
4545
PostgresMlEmbeddingProperties embeddingProperties) {
4646

47-
return new PostgresMlEmbeddingModel(jdbcTemplate, embeddingProperties.getOptions());
47+
return new PostgresMlEmbeddingModel(jdbcTemplate, embeddingProperties.getOptions(),
48+
embeddingProperties.isCreateExtension());
4849
}
4950

5051
}

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingProperties.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ public class PostgresMlEmbeddingProperties {
4141
*/
4242
private boolean enabled = true;
4343

44+
/**
45+
* Create the extensions required for embedding
46+
*/
47+
private boolean createExtension;
48+
4449
@NestedConfigurationProperty
4550
private PostgresMlEmbeddingOptions options = PostgresMlEmbeddingOptions.builder()
4651
.withTransformer(PostgresMlEmbeddingModel.DEFAULT_TRANSFORMER_MODEL)
@@ -71,4 +76,12 @@ public void setEnabled(boolean enabled) {
7176
this.enabled = enabled;
7277
}
7378

79+
public boolean isCreateExtension() {
80+
return createExtension;
81+
}
82+
83+
public void setCreateExtension(boolean createExtension) {
84+
this.createExtension = createExtension;
85+
}
86+
7487
}

0 commit comments

Comments
 (0)