Skip to content

Commit 6439a84

Browse files
sobychackoStudiousXiaoYu
authored andcommitted
Move batching strategy to base vector store builder
Moving BatchingStrategy configuration from individual vector store implementations to the base AbstractVectorStoreBuilder to reduce code duplication and provide consistent batching behavior across all vector stores. The default TokenCountBatchingStrategy is now set in the base builder class. Signed-off-by: StudiousXiaoYu <[email protected]>
1 parent 508cb3e commit 6439a84

File tree

22 files changed

+35
-292
lines changed

22 files changed

+35
-292
lines changed

spring-ai-core/src/main/java/org/springframework/ai/vectorstore/AbstractVectorStoreBuilder.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818

1919
import io.micrometer.observation.ObservationRegistry;
2020

21+
import org.springframework.ai.embedding.BatchingStrategy;
2122
import org.springframework.ai.embedding.EmbeddingModel;
23+
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
2224
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
2325
import org.springframework.lang.Nullable;
2426
import org.springframework.util.Assert;
@@ -40,6 +42,8 @@ public abstract class AbstractVectorStoreBuilder<T extends AbstractVectorStoreBu
4042
@Nullable
4143
protected VectorStoreObservationConvention customObservationConvention;
4244

45+
protected BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();
46+
4347
public AbstractVectorStoreBuilder(EmbeddingModel embeddingModel) {
4448
Assert.notNull(embeddingModel, "EmbeddingModel must be configured");
4549
this.embeddingModel = embeddingModel;
@@ -49,6 +53,10 @@ public EmbeddingModel getEmbeddingModel() {
4953
return this.embeddingModel;
5054
}
5155

56+
public BatchingStrategy getBatchingStrategy() {
57+
return this.batchingStrategy;
58+
}
59+
5260
public ObservationRegistry getObservationRegistry() {
5361
return this.observationRegistry;
5462
}
@@ -81,4 +89,15 @@ public T customObservationConvention(@Nullable VectorStoreObservationConvention
8189
return self();
8290
}
8391

92+
/**
93+
* Sets the batching strategy.
94+
* @param batchingStrategy the strategy to use
95+
* @return the builder instance
96+
*/
97+
public T batchingStrategy(BatchingStrategy batchingStrategy) {
98+
Assert.notNull(batchingStrategy, "BatchingStrategy must not be null");
99+
this.batchingStrategy = batchingStrategy;
100+
return self();
101+
}
102+
84103
}

spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import org.springframework.ai.document.Document;
2525
import org.springframework.ai.document.DocumentWriter;
26+
import org.springframework.ai.embedding.BatchingStrategy;
2627
import org.springframework.ai.vectorstore.observation.DefaultVectorStoreObservationConvention;
2728
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
2829
import org.springframework.lang.Nullable;
@@ -108,6 +109,13 @@ interface Builder<T extends Builder<T>> {
108109
*/
109110
T customObservationConvention(VectorStoreObservationConvention convention);
110111

112+
/**
113+
* Sets the batching strategy.
114+
* @param batchingStrategy the strategy to use
115+
* @return the builder instance for method chaining
116+
*/
117+
T batchingStrategy(BatchingStrategy batchingStrategy);
118+
111119
/**
112120
* Builds and returns a new VectorStore instance with the configured settings.
113121
* @return a new VectorStore instance

spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import io.micrometer.observation.ObservationRegistry;
2323

2424
import org.springframework.ai.document.Document;
25+
import org.springframework.ai.embedding.BatchingStrategy;
2526
import org.springframework.ai.embedding.EmbeddingModel;
2627
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
2728
import org.springframework.ai.vectorstore.SearchRequest;
@@ -47,11 +48,14 @@ public abstract class AbstractObservationVectorStore implements VectorStore {
4748

4849
protected final EmbeddingModel embeddingModel;
4950

51+
protected final BatchingStrategy batchingStrategy;
52+
5053
private AbstractObservationVectorStore(EmbeddingModel embeddingModel, ObservationRegistry observationRegistry,
51-
@Nullable VectorStoreObservationConvention customObservationConvention) {
54+
@Nullable VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) {
5255
this.embeddingModel = embeddingModel;
5356
this.observationRegistry = observationRegistry;
5457
this.customObservationConvention = customObservationConvention;
58+
this.batchingStrategy = batchingStrategy;
5559
}
5660

5761
/**
@@ -60,7 +64,8 @@ private AbstractObservationVectorStore(EmbeddingModel embeddingModel, Observatio
6064
* @param builder the builder containing configuration settings
6165
*/
6266
public AbstractObservationVectorStore(AbstractVectorStoreBuilder<?> builder) {
63-
this(builder.getEmbeddingModel(), builder.getObservationRegistry(), builder.getCustomObservationConvention());
67+
this(builder.getEmbeddingModel(), builder.getObservationRegistry(), builder.getCustomObservationConvention(),
68+
builder.getBatchingStrategy());
6469
}
6570

6671
/**

vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@
6060
import org.springframework.ai.embedding.BatchingStrategy;
6161
import org.springframework.ai.embedding.EmbeddingModel;
6262
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
63-
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
6463
import org.springframework.ai.observation.conventions.VectorStoreProvider;
6564
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
6665
import org.springframework.ai.vectorstore.SearchRequest;
@@ -96,8 +95,6 @@ public class CosmosDBVectorStore extends AbstractObservationVectorStore implemen
9695

9796
private final List<String> metadataFieldsList;
9897

99-
private final BatchingStrategy batchingStrategy;
100-
10198
private CosmosAsyncContainer container;
10299

103100
/**
@@ -120,7 +117,6 @@ protected CosmosDBVectorStore(Builder builder) {
120117
this.vectorStoreThroughput = builder.vectorStoreThroughput;
121118
this.vectorDimensions = builder.vectorDimensions;
122119
this.metadataFieldsList = builder.metadataFieldsList;
123-
this.batchingStrategy = builder.batchingStrategy;
124120

125121
this.cosmosClient.createDatabaseIfNotExists(this.databaseName).block();
126122
initializeContainer(this.containerName, this.databaseName, this.vectorStoreThroughput, this.vectorDimensions,
@@ -404,8 +400,6 @@ public static final class Builder extends AbstractVectorStoreBuilder<Builder> {
404400

405401
private List<String> metadataFieldsList = new ArrayList<>();
406402

407-
private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();
408-
409403
private Builder(CosmosAsyncClient cosmosClient, EmbeddingModel embeddingModel) {
410404
super(embeddingModel);
411405
Assert.notNull(cosmosClient, "CosmosClient must not be null");

vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,8 @@
4747

4848
import org.springframework.ai.document.Document;
4949
import org.springframework.ai.document.DocumentMetadata;
50-
import org.springframework.ai.embedding.BatchingStrategy;
5150
import org.springframework.ai.embedding.EmbeddingModel;
5251
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
53-
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
5452
import org.springframework.ai.model.EmbeddingUtils;
5553
import org.springframework.ai.observation.conventions.VectorStoreProvider;
5654
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
@@ -108,8 +106,6 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements
108106

109107
private final boolean initializeSchema;
110108

111-
private final BatchingStrategy batchingStrategy;
112-
113109
/**
114110
* List of metadata fields (as field name and type) that can be used in similarity
115111
* search query filter expressions. The {@link Document#getMetadata()} can contain
@@ -144,7 +140,6 @@ protected AzureVectorStore(Builder builder) {
144140
this.searchIndexClient = builder.searchIndexClient;
145141
this.initializeSchema = builder.initializeSchema;
146142
this.filterMetadataFields = builder.filterMetadataFields;
147-
this.batchingStrategy = builder.batchingStrategy;
148143
this.defaultTopK = builder.defaultTopK;
149144
this.defaultSimilarityThreshold = builder.defaultSimilarityThreshold;
150145
this.indexName = builder.indexName;
@@ -387,8 +382,6 @@ public static class Builder extends AbstractVectorStoreBuilder<Builder> {
387382

388383
private List<MetadataField> filterMetadataFields = List.of();
389384

390-
private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();
391-
392385
private int defaultTopK = DEFAULT_TOP_K;
393386

394387
private Double defaultSimilarityThreshold = DEFAULT_SIMILARITY_THRESHOLD;
@@ -421,17 +414,6 @@ public Builder filterMetadataFields(List<MetadataField> filterMetadataFields) {
421414
return this;
422415
}
423416

424-
/**
425-
* Sets the batching strategy.
426-
* @param batchingStrategy the strategy to use
427-
* @return the builder instance
428-
*/
429-
public Builder batchingStrategy(BatchingStrategy batchingStrategy) {
430-
Assert.notNull(batchingStrategy, "BatchingStrategy must not be null");
431-
this.batchingStrategy = batchingStrategy;
432-
return this;
433-
}
434-
435417
/**
436418
* Sets the index name for the Azure Vector Store.
437419
* @param indexName the name of the index to use

vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStore.java

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme
215215

216216
private final boolean closeSessionOnClose;
217217

218-
private final BatchingStrategy batchingStrategy;
219-
220218
private final ConcurrentMap<Set<String>, PreparedStatement> addStmts = new ConcurrentHashMap<>();
221219

222220
private final PreparedStatement deleteStmt;
@@ -237,7 +235,6 @@ protected CassandraVectorStore(Builder builder) {
237235
this.primaryKeyTranslator = builder.primaryKeyTranslator;
238236
this.executor = Executors.newFixedThreadPool(builder.fixedThreadPoolExecutorSize);
239237
this.closeSessionOnClose = builder.closeSessionOnClose;
240-
this.batchingStrategy = builder.batchingStrategy;
241238

242239
ensureSchemaExists(embeddingModel.dimensions());
243240
prepareAddStatement(Set.of());
@@ -775,8 +772,6 @@ public static class Builder extends AbstractVectorStoreBuilder<Builder> {
775772

776773
private int fixedThreadPoolExecutorSize = DEFAULT_ADD_CONCURRENCY;
777774

778-
private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();
779-
780775
private FilterExpressionConverter filterExpressionConverter;
781776

782777
private DocumentIdTranslator documentIdTranslator = (String id) -> List.of(id);
@@ -915,18 +910,6 @@ public Builder disallowSchemaChanges(boolean disallowSchemaChanges) {
915910
return this;
916911
}
917912

918-
/**
919-
* Sets the batching strategy.
920-
* @param batchingStrategy the batching strategy to use
921-
* @return the builder instance
922-
* @throws IllegalArgumentException if batchingStrategy is null
923-
*/
924-
public Builder batchingStrategy(BatchingStrategy batchingStrategy) {
925-
Assert.notNull(batchingStrategy, "BatchingStrategy must not be null");
926-
this.batchingStrategy = batchingStrategy;
927-
return this;
928-
}
929-
930913
/**
931914
* Sets the filter expression converter.
932915
* @param converter the filter expression converter to use

vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStore.java

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,6 @@ public class ChromaVectorStore extends AbstractObservationVectorStore implements
7777

7878
private final boolean initializeSchema;
7979

80-
private final BatchingStrategy batchingStrategy;
81-
8280
private final ObjectMapper objectMapper;
8381

8482
private boolean initialized = false;
@@ -93,7 +91,6 @@ protected ChromaVectorStore(Builder builder) {
9391
this.collectionName = builder.collectionName;
9492
this.initializeSchema = builder.initializeSchema;
9593
this.filterExpressionConverter = builder.filterExpressionConverter;
96-
this.batchingStrategy = builder.batchingStrategy;
9794
this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build();
9895

9996
if (builder.initializeImmediately) {
@@ -230,8 +227,6 @@ public static class Builder extends AbstractVectorStoreBuilder<Builder> {
230227

231228
private boolean initializeSchema = false;
232229

233-
private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();
234-
235230
private FilterExpressionConverter filterExpressionConverter = new ChromaFilterExpressionConverter();
236231

237232
private boolean initializeImmediately = false;
@@ -264,18 +259,6 @@ public Builder initializeSchema(boolean initializeSchema) {
264259
return this;
265260
}
266261

267-
/**
268-
* Sets the batching strategy.
269-
* @param batchingStrategy the batching strategy to use
270-
* @return the builder instance
271-
* @throws IllegalArgumentException if batchingStrategy is null
272-
*/
273-
public Builder batchingStrategy(BatchingStrategy batchingStrategy) {
274-
Assert.notNull(batchingStrategy, "batchingStrategy must not be null");
275-
this.batchingStrategy = batchingStrategy;
276-
return this;
277-
}
278-
279262
/**
280263
* Sets the filter expression converter.
281264
* @param converter the filter expression converter to use

vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStore.java

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,6 @@ public class ElasticsearchVectorStore extends AbstractObservationVectorStore imp
162162

163163
private final boolean initializeSchema;
164164

165-
private final BatchingStrategy batchingStrategy;
166-
167165
protected ElasticsearchVectorStore(Builder builder) {
168166
super(builder);
169167

@@ -172,7 +170,6 @@ protected ElasticsearchVectorStore(Builder builder) {
172170
this.initializeSchema = builder.initializeSchema;
173171
this.options = builder.options;
174172
this.filterExpressionConverter = builder.filterExpressionConverter;
175-
this.batchingStrategy = builder.batchingStrategy;
176173

177174
String version = Version.VERSION == null ? "Unknown" : Version.VERSION.toString();
178175
this.elasticsearchClient = new ElasticsearchClient(new RestClientTransport(builder.restClient,
@@ -369,8 +366,6 @@ public static class Builder extends AbstractVectorStoreBuilder<Builder> {
369366

370367
private boolean initializeSchema = false;
371368

372-
private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();
373-
374369
private FilterExpressionConverter filterExpressionConverter = new ElasticsearchAiSearchFilterExpressionConverter();
375370

376371
/**
@@ -406,18 +401,6 @@ public Builder initializeSchema(boolean initializeSchema) {
406401
return this;
407402
}
408403

409-
/**
410-
* Sets the batching strategy for vector operations.
411-
* @param batchingStrategy the batching strategy to use
412-
* @return the builder instance
413-
* @throws IllegalArgumentException if batchingStrategy is null
414-
*/
415-
public Builder batchingStrategy(BatchingStrategy batchingStrategy) {
416-
Assert.notNull(batchingStrategy, "batchingStrategy must not be null");
417-
this.batchingStrategy = batchingStrategy;
418-
return this;
419-
}
420-
421404
/**
422405
* Sets the filter expression converter.
423406
* @param converter the filter expression converter to use

vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/gemfire/GemFireVectorStore.java

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,6 @@ public class GemFireVectorStore extends AbstractObservationVectorStore implement
103103

104104
private final boolean initializeSchema;
105105

106-
private final BatchingStrategy batchingStrategy;
107-
108106
private final ObjectMapper objectMapper;
109107

110108
private final String indexName;
@@ -134,7 +132,6 @@ protected GemFireVectorStore(Builder builder) {
134132
this.buckets = builder.buckets;
135133
this.vectorSimilarityFunction = builder.vectorSimilarityFunction;
136134
this.fields = builder.fields;
137-
this.batchingStrategy = builder.batchingStrategy;
138135

139136
String base = UriComponentsBuilder.fromUriString(DEFAULT_URI)
140137
.build(builder.sslEnabled ? "s" : "", builder.host, builder.port)
@@ -584,8 +581,6 @@ public static final class Builder extends AbstractVectorStoreBuilder<Builder> {
584581

585582
private boolean initializeSchema = false;
586583

587-
private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();
588-
589584
private Builder(EmbeddingModel embeddingModel) {
590585
super(embeddingModel);
591586
}
@@ -708,18 +703,6 @@ public Builder initializeSchema(boolean initializeSchema) {
708703
return this;
709704
}
710705

711-
/**
712-
* Sets the batching strategy.
713-
* @param batchingStrategy the strategy to use
714-
* @return the builder instance
715-
* @throws IllegalArgumentException if batchingStrategy is null
716-
*/
717-
public Builder batchingStrategy(BatchingStrategy batchingStrategy) {
718-
Assert.notNull(batchingStrategy, "BatchingStrategy must not be null");
719-
this.batchingStrategy = batchingStrategy;
720-
return this;
721-
}
722-
723706
@Override
724707
public GemFireVectorStore build() {
725708
return new GemFireVectorStore(this);

0 commit comments

Comments
 (0)