|
| 1 | +package com.microsoft.azure.spring.chatgpt.sample.common.store; |
| 2 | + |
| 3 | +import com.azure.cosmos.CosmosAsyncClient; |
| 4 | +import com.azure.cosmos.CosmosAsyncDatabase; |
| 5 | +import com.azure.cosmos.implementation.guava25.collect.ImmutableList; |
| 6 | +import com.azure.cosmos.models.CosmosContainerProperties; |
| 7 | +import com.azure.cosmos.models.CosmosContainerResponse; |
| 8 | +import com.azure.cosmos.models.CosmosVectorDataType; |
| 9 | +import com.azure.cosmos.models.CosmosVectorDistanceFunction; |
| 10 | +import com.azure.cosmos.models.CosmosVectorEmbedding; |
| 11 | +import com.azure.cosmos.models.CosmosVectorEmbeddingPolicy; |
| 12 | +import com.azure.cosmos.models.CosmosVectorIndexSpec; |
| 13 | +import com.azure.cosmos.models.CosmosVectorIndexType; |
| 14 | +import com.azure.cosmos.models.ExcludedPath; |
| 15 | +import com.azure.cosmos.models.IncludedPath; |
| 16 | +import com.azure.cosmos.models.IndexingMode; |
| 17 | +import com.azure.cosmos.models.IndexingPolicy; |
| 18 | +import com.azure.cosmos.models.ThroughputProperties; |
| 19 | +import com.azure.spring.data.cosmos.repository.config.EnableCosmosRepositories; |
| 20 | +import com.fasterxml.jackson.databind.ObjectMapper; |
| 21 | +import org.springframework.beans.factory.annotation.Autowired; |
| 22 | +import org.springframework.context.ApplicationContext; |
| 23 | +import org.springframework.stereotype.Component; |
| 24 | +import java.io.File; |
| 25 | +import java.util.ArrayList; |
| 26 | +import java.util.Arrays; |
| 27 | +import java.util.Collections; |
| 28 | +import java.util.List; |
| 29 | +import java.util.Map; |
| 30 | +import java.util.concurrent.ConcurrentHashMap; |
| 31 | +import java.util.logging.Logger; |
| 32 | +import java.util.stream.Collectors; |
| 33 | + |
| 34 | +@Component |
| 35 | +@EnableCosmosRepositories (basePackages = "com.microsoft.azure.spring.chatgpt.sample.common.vectorstore") |
| 36 | +public class CosmosDBVectorStore implements VectorStore { |
| 37 | + |
| 38 | + private final VectorStoreData data; |
| 39 | + |
| 40 | + @Autowired |
| 41 | + private CosmosEntityRepository cosmosEntityRepository; |
| 42 | + |
| 43 | + private String containerName; |
| 44 | + |
| 45 | + private String databaseName; |
| 46 | + |
| 47 | + private ApplicationContext applicationContext; |
| 48 | + |
| 49 | + private Logger log = Logger.getLogger(CosmosDBVectorStore.class.getName()); |
| 50 | + |
| 51 | + public CosmosAsyncClient client; |
| 52 | + |
| 53 | + public CosmosDBVectorStore(CosmosEntityRepository cosmosEntityRepository, String containerName, String databaseName, ApplicationContext applicationContext) { |
| 54 | + this.cosmosEntityRepository = cosmosEntityRepository; |
| 55 | + this.applicationContext = applicationContext; |
| 56 | + client = applicationContext.getBean(CosmosAsyncClient.class); |
| 57 | + this.containerName = containerName; |
| 58 | + this.databaseName = databaseName; |
| 59 | + this.data = new VectorStoreData(); |
| 60 | + } |
| 61 | + |
| 62 | + @Override |
| 63 | + public void saveDocument(String key, CosmosEntity doc) { |
| 64 | + cosmosEntityRepository.save(doc); |
| 65 | + } |
| 66 | + |
| 67 | + @Override |
| 68 | + public CosmosEntity getDocument(String key) { |
| 69 | + var doc = cosmosEntityRepository.findById(key).get(); |
| 70 | + return doc; |
| 71 | + } |
| 72 | + |
| 73 | + @Override |
| 74 | + public void removeDocument(String key) { |
| 75 | + cosmosEntityRepository.deleteById(key); |
| 76 | + } |
| 77 | + |
| 78 | + @Override |
| 79 | + public List<CosmosEntity> searchTopKNearest(List<Double> embedding, int k) { |
| 80 | + return searchTopKNearest(embedding, k, 0); |
| 81 | + } |
| 82 | + |
| 83 | + @Override |
| 84 | + public List<CosmosEntity> searchTopKNearest(List<Double> embedding, int k, double cutOff) { |
| 85 | + Object embeddingParam = embedding.stream().map(aDouble -> (Float) (float) aDouble.doubleValue()).collect(Collectors.toList()).toArray(); |
| 86 | + ArrayList<CosmosEntity> results = cosmosEntityRepository.vectorSearch(embeddingParam); |
| 87 | + return results; |
| 88 | + } |
| 89 | + |
| 90 | + public void createVectorIndex(int numLists, int dimensions, String similarity) { |
| 91 | + |
| 92 | + CosmosContainerProperties collectionDefinition = new CosmosContainerProperties(containerName, "/id"); |
| 93 | + |
| 94 | + //set vector embedding policy |
| 95 | + CosmosVectorEmbeddingPolicy cosmosVectorEmbeddingPolicy = new CosmosVectorEmbeddingPolicy(); |
| 96 | + CosmosVectorEmbedding embedding = new CosmosVectorEmbedding(); |
| 97 | + embedding.setPath("/embedding"); |
| 98 | + embedding.setDataType(CosmosVectorDataType.FLOAT32); |
| 99 | + embedding.setDimensions(1536L); |
| 100 | + embedding.setDistanceFunction(CosmosVectorDistanceFunction.COSINE); |
| 101 | + cosmosVectorEmbeddingPolicy.setCosmosVectorEmbeddings(Arrays.asList(embedding)); |
| 102 | + collectionDefinition.setVectorEmbeddingPolicy(cosmosVectorEmbeddingPolicy); |
| 103 | + |
| 104 | + //set vector indexing policy |
| 105 | + IndexingPolicy indexingPolicy = new IndexingPolicy(); |
| 106 | + indexingPolicy.setIndexingMode(IndexingMode.CONSISTENT); |
| 107 | + ExcludedPath excludedPath = new ExcludedPath("/*"); |
| 108 | + indexingPolicy.setExcludedPaths(Collections.singletonList(excludedPath)); |
| 109 | + IncludedPath includedPath1 = new IncludedPath("/hash/?"); |
| 110 | + IncludedPath includedPath2 = new IncludedPath("/text/?"); |
| 111 | + indexingPolicy.setIncludedPaths(ImmutableList.of(includedPath1, includedPath2)); |
| 112 | + CosmosVectorIndexSpec cosmosVectorIndexSpec = new CosmosVectorIndexSpec(); |
| 113 | + cosmosVectorIndexSpec.setPath("/embedding"); |
| 114 | + cosmosVectorIndexSpec.setType(CosmosVectorIndexType.DISK_ANN.toString()); |
| 115 | + indexingPolicy.setVectorIndexes(Arrays.asList(cosmosVectorIndexSpec)); |
| 116 | + collectionDefinition.setIndexingPolicy(indexingPolicy); |
| 117 | + |
| 118 | + //create container |
| 119 | + ThroughputProperties throughputProperties = ThroughputProperties.createManualThroughput(400); |
| 120 | + client.createDatabaseIfNotExists(databaseName).block(); |
| 121 | + CosmosAsyncDatabase database = client.getDatabase(databaseName); |
| 122 | + CosmosContainerResponse containerResponse = database.createContainerIfNotExists(collectionDefinition, throughputProperties).block(); |
| 123 | + } |
| 124 | + |
| 125 | + public List<CosmosEntity> loadFromJsonFile(String filePath) { |
| 126 | + var reader = new ObjectMapper().reader(); |
| 127 | + try { |
| 128 | + int dimensions = 0; |
| 129 | + var data = reader.readValue(new File(filePath), VectorStoreData.class); |
| 130 | + List<CosmosEntity> list = new ArrayList<CosmosEntity>(data.store.values()); |
| 131 | + List<CosmosEntity> cosmosEntities = new ArrayList<>(); |
| 132 | + try { |
| 133 | + createVectorIndex(100, dimensions, "COS"); |
| 134 | + cosmosEntityRepository.saveAll(list); |
| 135 | + } catch (Exception e) { |
| 136 | + log.warning("Failed to insertAll documents to Cosmos DB NoSQL API, attempting individual upserts: "+ e.getMessage()); |
| 137 | + for (CosmosEntity cosmosEntity : list) { |
| 138 | + log.info("Saving document {} to Cosmos DB NoSQL API" + cosmosEntity.getId()); |
| 139 | + try { |
| 140 | + cosmosEntityRepository.save(cosmosEntity); |
| 141 | + } catch (Exception ex) { |
| 142 | + log.warning("Failed to upsert document "+ cosmosEntity.getId()+ "to Cosmos DB:" + ex); |
| 143 | + } |
| 144 | + } |
| 145 | + } |
| 146 | + return cosmosEntities; |
| 147 | + } catch (Exception e) { |
| 148 | + throw new RuntimeException(e); |
| 149 | + } |
| 150 | + } |
| 151 | + private static class VectorStoreData { |
| 152 | + public Map<String, CosmosEntity> getStore() { |
| 153 | + return store; |
| 154 | + } |
| 155 | + public void setStore(Map<String, CosmosEntity> store) { |
| 156 | + this.store = store; |
| 157 | + } |
| 158 | + private Map<String, CosmosEntity> store = new ConcurrentHashMap<>(); |
| 159 | + } |
| 160 | +} |
0 commit comments