Skip to content

Commit d9d0467

Browse files
committed
add vector search code
1 parent 6bd468e commit d9d0467

File tree

11 files changed

+346
-15
lines changed

11 files changed

+346
-15
lines changed

Java/CosmosDB-NoSQL-RAG-Chatbot/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ The following prerequisites are required to use this application. Please ensure
6161
mvn clean package
6262
```
6363

64-
4. The following command will read and process your own private text documents, create a Cosmos DB NoSQL API collection with [vector indexing](https://learn.microsoft.com/azure/cosmos-db/nosql/vector-search#vector-indexing-policies) and [embeddings](https://learn.microsoft.com/azure/cosmos-db/nosql/vector-search#container-vector-policies) policies (see `com.microsoft.azure.springchatgpt.sample.common.vectorstore.CosmosDBVectorStore.java`), and load the processed documents into it:
64+
4. The following command will read and process your own private text documents, create a Cosmos DB NoSQL API collection with [vector indexing](https://learn.microsoft.com/azure/cosmos-db/nosql/vector-search#vector-indexing-policies) and [embeddings](https://learn.microsoft.com/azure/cosmos-db/nosql/vector-search#container-vector-policies) policies (see `com.microsoft.azure.spring.chatgpt.sample.common.store.CosmosDBVectorStore.java`), and load the processed documents into it:
6565

6666
```shell
6767
java -jar spring-chatgpt-sample-cli/target/spring-chatgpt-sample-cli-0.0.1-SNAPSHOT.jar --from=C:/<path to your private text docs>

Java/CosmosDB-NoSQL-RAG-Chatbot/spring-chatgpt-sample-cli/src/main/java/com/microsoft/azure/spring/chatgpt/sample/cli/Config.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
import com.azure.spring.data.cosmos.repository.config.EnableCosmosRepositories;
99
import com.microsoft.azure.spring.chatgpt.sample.common.AzureOpenAIClient;
1010
import com.microsoft.azure.spring.chatgpt.sample.common.DocumentIndexPlanner;
11-
import com.microsoft.azure.spring.chatgpt.sample.common.vectorstore.CosmosDBVectorStore;
12-
import com.microsoft.azure.spring.chatgpt.sample.common.vectorstore.CosmosEntityRepository;
13-
import com.microsoft.azure.spring.chatgpt.sample.common.vectorstore.CosmosProperties;
11+
import com.microsoft.azure.spring.chatgpt.sample.common.store.CosmosDBVectorStore;
12+
import com.microsoft.azure.spring.chatgpt.sample.common.store.CosmosEntityRepository;
13+
import com.microsoft.azure.spring.chatgpt.sample.common.store.CosmosProperties;
1414
import org.springframework.beans.factory.annotation.Autowired;
1515
import org.springframework.beans.factory.annotation.Value;
1616
import org.springframework.boot.context.properties.EnableConfigurationProperties;
@@ -22,7 +22,7 @@
2222

2323
@Configuration
2424
@EnableConfigurationProperties(CosmosProperties.class)
25-
@EnableCosmosRepositories(basePackages = "com.microsoft.azure.spring.chatgpt.sample.common.vectorstore")
25+
@EnableCosmosRepositories(basePackages = "com.microsoft.azure.spring.chatgpt.sample.common.store")
2626
public class Config extends AbstractCosmosConfiguration {
2727

2828
@Value("${AZURE_OPENAI_EMBEDDINGDEPLOYMENTID}")

Java/CosmosDB-NoSQL-RAG-Chatbot/spring-chatgpt-sample-common/src/main/java/com/microsoft/azure/spring/chatgpt/sample/common/ChatPlanner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import com.azure.ai.openai.models.ChatMessage;
55
import com.azure.ai.openai.models.ChatRole;
66
import com.microsoft.azure.spring.chatgpt.sample.common.prompt.PromptTemplate;
7-
import com.microsoft.azure.spring.chatgpt.sample.common.vectorstore.CosmosEntity;
8-
import com.microsoft.azure.spring.chatgpt.sample.common.vectorstore.VectorStore;
7+
import com.microsoft.azure.spring.chatgpt.sample.common.store.CosmosEntity;
8+
import com.microsoft.azure.spring.chatgpt.sample.common.store.VectorStore;
99

1010
import java.util.ArrayList;
1111
import java.util.List;

Java/CosmosDB-NoSQL-RAG-Chatbot/spring-chatgpt-sample-common/src/main/java/com/microsoft/azure/spring/chatgpt/sample/common/DocumentIndexPlanner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
package com.microsoft.azure.spring.chatgpt.sample.common;
22

33
import com.microsoft.azure.spring.chatgpt.sample.common.reader.SimpleFolderReader;
4-
import com.microsoft.azure.spring.chatgpt.sample.common.vectorstore.CosmosDBVectorStore;
5-
import com.microsoft.azure.spring.chatgpt.sample.common.vectorstore.CosmosEntity;
4+
import com.microsoft.azure.spring.chatgpt.sample.common.store.CosmosDBVectorStore;
5+
import com.microsoft.azure.spring.chatgpt.sample.common.store.CosmosEntity;
66

77
import java.io.IOException;
88
import java.util.List;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package com.microsoft.azure.spring.chatgpt.sample.common.store;
2+
3+
import org.springframework.context.annotation.Configuration;
4+
5+
@Configuration
6+
public class Config {
7+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
package com.microsoft.azure.spring.chatgpt.sample.common.store;
2+
3+
import com.azure.cosmos.CosmosAsyncClient;
4+
import com.azure.cosmos.CosmosAsyncDatabase;
5+
import com.azure.cosmos.implementation.guava25.collect.ImmutableList;
6+
import com.azure.cosmos.models.CosmosContainerProperties;
7+
import com.azure.cosmos.models.CosmosContainerResponse;
8+
import com.azure.cosmos.models.CosmosVectorDataType;
9+
import com.azure.cosmos.models.CosmosVectorDistanceFunction;
10+
import com.azure.cosmos.models.CosmosVectorEmbedding;
11+
import com.azure.cosmos.models.CosmosVectorEmbeddingPolicy;
12+
import com.azure.cosmos.models.CosmosVectorIndexSpec;
13+
import com.azure.cosmos.models.CosmosVectorIndexType;
14+
import com.azure.cosmos.models.ExcludedPath;
15+
import com.azure.cosmos.models.IncludedPath;
16+
import com.azure.cosmos.models.IndexingMode;
17+
import com.azure.cosmos.models.IndexingPolicy;
18+
import com.azure.cosmos.models.ThroughputProperties;
19+
import com.azure.spring.data.cosmos.repository.config.EnableCosmosRepositories;
20+
import com.fasterxml.jackson.databind.ObjectMapper;
21+
import org.springframework.beans.factory.annotation.Autowired;
22+
import org.springframework.context.ApplicationContext;
23+
import org.springframework.stereotype.Component;
24+
import java.io.File;
25+
import java.util.ArrayList;
26+
import java.util.Arrays;
27+
import java.util.Collections;
28+
import java.util.List;
29+
import java.util.Map;
30+
import java.util.concurrent.ConcurrentHashMap;
31+
import java.util.logging.Logger;
32+
import java.util.stream.Collectors;
33+
34+
@Component
35+
@EnableCosmosRepositories (basePackages = "com.microsoft.azure.spring.chatgpt.sample.common.vectorstore")
36+
public class CosmosDBVectorStore implements VectorStore {
37+
38+
private final VectorStoreData data;
39+
40+
@Autowired
41+
private CosmosEntityRepository cosmosEntityRepository;
42+
43+
private String containerName;
44+
45+
private String databaseName;
46+
47+
private ApplicationContext applicationContext;
48+
49+
private Logger log = Logger.getLogger(CosmosDBVectorStore.class.getName());
50+
51+
public CosmosAsyncClient client;
52+
53+
public CosmosDBVectorStore(CosmosEntityRepository cosmosEntityRepository, String containerName, String databaseName, ApplicationContext applicationContext) {
54+
this.cosmosEntityRepository = cosmosEntityRepository;
55+
this.applicationContext = applicationContext;
56+
client = applicationContext.getBean(CosmosAsyncClient.class);
57+
this.containerName = containerName;
58+
this.databaseName = databaseName;
59+
this.data = new VectorStoreData();
60+
}
61+
62+
@Override
63+
public void saveDocument(String key, CosmosEntity doc) {
64+
cosmosEntityRepository.save(doc);
65+
}
66+
67+
@Override
68+
public CosmosEntity getDocument(String key) {
69+
var doc = cosmosEntityRepository.findById(key).get();
70+
return doc;
71+
}
72+
73+
@Override
74+
public void removeDocument(String key) {
75+
cosmosEntityRepository.deleteById(key);
76+
}
77+
78+
@Override
79+
public List<CosmosEntity> searchTopKNearest(List<Double> embedding, int k) {
80+
return searchTopKNearest(embedding, k, 0);
81+
}
82+
83+
@Override
84+
public List<CosmosEntity> searchTopKNearest(List<Double> embedding, int k, double cutOff) {
85+
Object embeddingParam = embedding.stream().map(aDouble -> (Float) (float) aDouble.doubleValue()).collect(Collectors.toList()).toArray();
86+
ArrayList<CosmosEntity> results = cosmosEntityRepository.vectorSearch(embeddingParam);
87+
return results;
88+
}
89+
90+
public void createVectorIndex(int numLists, int dimensions, String similarity) {
91+
92+
CosmosContainerProperties collectionDefinition = new CosmosContainerProperties(containerName, "/id");
93+
94+
//set vector embedding policy
95+
CosmosVectorEmbeddingPolicy cosmosVectorEmbeddingPolicy = new CosmosVectorEmbeddingPolicy();
96+
CosmosVectorEmbedding embedding = new CosmosVectorEmbedding();
97+
embedding.setPath("/embedding");
98+
embedding.setDataType(CosmosVectorDataType.FLOAT32);
99+
embedding.setDimensions(1536L);
100+
embedding.setDistanceFunction(CosmosVectorDistanceFunction.COSINE);
101+
cosmosVectorEmbeddingPolicy.setCosmosVectorEmbeddings(Arrays.asList(embedding));
102+
collectionDefinition.setVectorEmbeddingPolicy(cosmosVectorEmbeddingPolicy);
103+
104+
//set vector indexing policy
105+
IndexingPolicy indexingPolicy = new IndexingPolicy();
106+
indexingPolicy.setIndexingMode(IndexingMode.CONSISTENT);
107+
ExcludedPath excludedPath = new ExcludedPath("/*");
108+
indexingPolicy.setExcludedPaths(Collections.singletonList(excludedPath));
109+
IncludedPath includedPath1 = new IncludedPath("/hash/?");
110+
IncludedPath includedPath2 = new IncludedPath("/text/?");
111+
indexingPolicy.setIncludedPaths(ImmutableList.of(includedPath1, includedPath2));
112+
CosmosVectorIndexSpec cosmosVectorIndexSpec = new CosmosVectorIndexSpec();
113+
cosmosVectorIndexSpec.setPath("/embedding");
114+
cosmosVectorIndexSpec.setType(CosmosVectorIndexType.DISK_ANN.toString());
115+
indexingPolicy.setVectorIndexes(Arrays.asList(cosmosVectorIndexSpec));
116+
collectionDefinition.setIndexingPolicy(indexingPolicy);
117+
118+
//create container
119+
ThroughputProperties throughputProperties = ThroughputProperties.createManualThroughput(400);
120+
client.createDatabaseIfNotExists(databaseName).block();
121+
CosmosAsyncDatabase database = client.getDatabase(databaseName);
122+
CosmosContainerResponse containerResponse = database.createContainerIfNotExists(collectionDefinition, throughputProperties).block();
123+
}
124+
125+
public List<CosmosEntity> loadFromJsonFile(String filePath) {
126+
var reader = new ObjectMapper().reader();
127+
try {
128+
int dimensions = 0;
129+
var data = reader.readValue(new File(filePath), VectorStoreData.class);
130+
List<CosmosEntity> list = new ArrayList<CosmosEntity>(data.store.values());
131+
List<CosmosEntity> cosmosEntities = new ArrayList<>();
132+
try {
133+
createVectorIndex(100, dimensions, "COS");
134+
cosmosEntityRepository.saveAll(list);
135+
} catch (Exception e) {
136+
log.warning("Failed to insertAll documents to Cosmos DB NoSQL API, attempting individual upserts: "+ e.getMessage());
137+
for (CosmosEntity cosmosEntity : list) {
138+
log.info("Saving document {} to Cosmos DB NoSQL API" + cosmosEntity.getId());
139+
try {
140+
cosmosEntityRepository.save(cosmosEntity);
141+
} catch (Exception ex) {
142+
log.warning("Failed to upsert document "+ cosmosEntity.getId()+ "to Cosmos DB:" + ex);
143+
}
144+
}
145+
}
146+
return cosmosEntities;
147+
} catch (Exception e) {
148+
throw new RuntimeException(e);
149+
}
150+
}
151+
private static class VectorStoreData {
152+
public Map<String, CosmosEntity> getStore() {
153+
return store;
154+
}
155+
public void setStore(Map<String, CosmosEntity> store) {
156+
this.store = store;
157+
}
158+
private Map<String, CosmosEntity> store = new ConcurrentHashMap<>();
159+
}
160+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package com.microsoft.azure.spring.chatgpt.sample.common.store;
2+
3+
import com.azure.spring.data.cosmos.core.mapping.Container;
4+
import com.azure.spring.data.cosmos.core.mapping.PartitionKey;
5+
import org.springframework.data.annotation.Id;
6+
7+
import java.util.List;
8+
9+
@Container(containerName = "vectorstore", autoCreateContainer = false)
10+
public class CosmosEntity {
11+
@Id
12+
@PartitionKey
13+
private String id;
14+
private String hash;
15+
private String text;
16+
private List<Double> embedding;
17+
18+
public CosmosEntity() {}
19+
public CosmosEntity(String id, String hash, String text, List<Double> embedding) {
20+
this.id = id;
21+
this.hash = hash;
22+
this.text = text;
23+
this.embedding = embedding;
24+
}
25+
26+
public String getId() {
27+
return id;
28+
}
29+
30+
public void setId(String id) {
31+
this.id = id;
32+
}
33+
34+
public String getHash() {
35+
return hash;
36+
}
37+
38+
public void setHash(String hash) {
39+
this.hash = hash;
40+
}
41+
42+
public String getText() {
43+
return text;
44+
}
45+
46+
public void setText(String text) {
47+
this.text = text;
48+
}
49+
50+
public List<Double> getEmbedding() {
51+
return embedding;
52+
}
53+
54+
public void setEmbedding(List<Double> embedding) {
55+
this.embedding = embedding;
56+
}
57+
58+
@Override
59+
public String toString() {
60+
return "Vector{" +
61+
"id='" + id + '\'' +
62+
", hash='" + hash + '\'' +
63+
", text='" + text + '\'' +
64+
", embedding='" + embedding + '\'' +
65+
'}';
66+
}
67+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package com.microsoft.azure.spring.chatgpt.sample.common.store;
2+
import com.azure.spring.data.cosmos.repository.CosmosRepository;
3+
import com.azure.spring.data.cosmos.repository.Query;
4+
import org.springframework.data.repository.query.Param;
5+
import org.springframework.stereotype.Repository;
6+
import java.util.ArrayList;
7+
8+
@Repository
9+
public interface CosmosEntityRepository extends CosmosRepository<CosmosEntity, String> {
10+
@Query(value = "SELECT TOP 3 c.id, c.embedding, c.hash, c.text, VectorDistance(c.embedding,@embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.embedding,@embedding)")
11+
ArrayList<CosmosEntity> vectorSearch(@Param("embedding") Object embedding);
12+
13+
@Query(value = "SELECT c.id FROM c where c.id = @id")
14+
ArrayList<CosmosEntity> findRecord(@Param("embedding") String id);
15+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package com.microsoft.azure.spring.chatgpt.sample.common.store;
2+
3+
import org.springframework.boot.context.properties.ConfigurationProperties;
4+
5+
@ConfigurationProperties(prefix = "spring.data.cosmos")
6+
public class CosmosProperties {
7+
8+
private String uri;
9+
10+
private String key;
11+
12+
private String secondaryKey;
13+
14+
private String databaseName;
15+
16+
public String getContainerName() {
17+
return containerName;
18+
}
19+
20+
public void setContainerName(String containerName) {
21+
this.containerName = containerName;
22+
}
23+
24+
private String containerName;
25+
26+
private boolean queryMetricsEnabled;
27+
28+
public String getUri() {
29+
return uri;
30+
}
31+
32+
public void setUri(String uri) {
33+
this.uri = uri;
34+
}
35+
36+
public String getKey() {
37+
return key;
38+
}
39+
40+
public void setKey(String key) {
41+
this.key = key;
42+
}
43+
44+
public String getSecondaryKey() {
45+
return secondaryKey;
46+
}
47+
48+
public void setSecondaryKey(String secondaryKey) {
49+
this.secondaryKey = secondaryKey;
50+
}
51+
52+
public void setDatabaseName(String databaseName) {
53+
this.databaseName = databaseName;
54+
}
55+
56+
public String getDatabaseName() {
57+
return databaseName;
58+
}
59+
60+
public boolean isQueryMetricsEnabled() {
61+
return queryMetricsEnabled;
62+
}
63+
64+
public void setQueryMetricsEnabled(boolean enableQueryMetrics) {
65+
this.queryMetricsEnabled = enableQueryMetrics;
66+
}
67+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package com.microsoft.azure.spring.chatgpt.sample.common.store;
2+
3+
import java.util.List;
4+
5+
public interface VectorStore {
6+
void saveDocument(String key, CosmosEntity doc);
7+
8+
CosmosEntity getDocument(String key);
9+
10+
void removeDocument(String key);
11+
12+
List<CosmosEntity> searchTopKNearest(List<Double> embedding, int k);
13+
14+
List<CosmosEntity> searchTopKNearest(List<Double> embedding, int k, double cutOff);
15+
}

0 commit comments

Comments
 (0)