Skip to content

Commit 4ca7dc0

Browse files
committed
feat: Update Milvus integration with new API, enhance collection management, and improve embedding storage
1 parent 72a6522 commit 4ca7dc0

File tree

4 files changed

+175
-57
lines changed

4 files changed

+175
-57
lines changed

backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/application/KnowledgeBaseService.java

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,11 @@
1717
import com.datamate.rag.indexer.infrastructure.milvus.MilvusService;
1818
import com.datamate.rag.indexer.interfaces.dto.*;
1919
import com.google.gson.JsonObject;
20-
import io.milvus.grpc.QueryResults;
21-
import io.milvus.param.R;
22-
import io.milvus.param.collection.DropCollectionParam;
23-
import io.milvus.param.collection.RenameCollectionParam;
24-
import io.milvus.param.dml.DeleteParam;
25-
import io.milvus.param.dml.QueryParam;
26-
import io.milvus.response.QueryResultsWrapper;
20+
import io.milvus.v2.service.collection.request.DropCollectionReq;
21+
import io.milvus.v2.service.collection.request.RenameCollectionReq;
22+
import io.milvus.v2.service.vector.request.DeleteReq;
23+
import io.milvus.v2.service.vector.request.QueryReq;
24+
import io.milvus.v2.service.vector.response.QueryResp;
2725
import lombok.RequiredArgsConstructor;
2826
import org.jetbrains.annotations.NotNull;
2927
import org.springframework.beans.BeanUtils;
@@ -34,7 +32,6 @@
3432

3533
import java.util.Collections;
3634
import java.util.List;
37-
import java.util.Map;
3835
import java.util.Optional;
3936

4037
/**
@@ -78,9 +75,9 @@ public void update(String knowledgeBaseId, KnowledgeBaseUpdateReq request) {
7875
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
7976
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
8077
if (StringUtils.hasText(request.getName())) {
81-
milvusService.getMilvusClient().renameCollection(RenameCollectionParam.newBuilder()
82-
.withOldCollectionName(knowledgeBase.getName())
83-
.withNewCollectionName(request.getName())
78+
milvusService.getMilvusClient().renameCollection(RenameCollectionReq.builder()
79+
.collectionName(knowledgeBase.getName())
80+
.newCollectionName(request.getName())
8481
.build());
8582
knowledgeBase.setName(request.getName());
8683
}
@@ -102,7 +99,7 @@ public void delete(String knowledgeBaseId) {
10299
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
103100
knowledgeBaseRepository.removeById(knowledgeBaseId);
104101
ragFileRepository.removeByKnowledgeBaseId(knowledgeBaseId);
105-
milvusService.getMilvusClient().dropCollection(DropCollectionParam.newBuilder().withCollectionName(knowledgeBase.getName()).build());
102+
milvusService.getMilvusClient().dropCollection(DropCollectionReq.builder().collectionName(knowledgeBase.getName()).build());
106103
}
107104

108105
public KnowledgeBaseResp getById(String knowledgeBaseId) {
@@ -168,41 +165,39 @@ public void deleteFiles(String knowledgeBaseId, DeleteFilesReq request) {
168165
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
169166
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
170167
ragFileRepository.removeByIds(request.getIds());
171-
milvusService.getMilvusClient().delete(DeleteParam.newBuilder()
172-
.withCollectionName(knowledgeBase.getName())
173-
.withExpr("metadata[\"rag_file_id\"] in [" + org.apache.commons.lang3.StringUtils.join(request.getIds().stream().map(id -> "\"" + id + "\"").toArray(), ",") + "]")
168+
milvusService.getMilvusClient().delete(DeleteReq.builder()
169+
.collectionName(knowledgeBase.getName())
170+
.filter("metadata[\"rag_file_id\"] in [" + org.apache.commons.lang3.StringUtils.join(request.getIds().stream().map(id -> "\"" + id + "\"").toArray(), ",") + "]")
174171
.build());
175172
}
176173

177174
public PagedResponse<RagChunk> getChunks(String knowledgeBaseId, String ragFileId, PagingQuery pagingQuery) {
178175
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
179176
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
180-
R<QueryResults> results = milvusService.getMilvusClient().query(QueryParam.newBuilder()
181-
.withCollectionName(knowledgeBase.getName())
182-
.withExpr("metadata[\"rag_file_id\"] == \"" + ragFileId + "\"")
183-
.withOutFields(Collections.singletonList("*"))
184-
.withLimit(Long.valueOf(pagingQuery.getSize()))
185-
.withOffset((long) (pagingQuery.getPage() - 1) * pagingQuery.getSize())
177+
QueryResp results = milvusService.getMilvusClient().query(QueryReq.builder()
178+
.collectionName(knowledgeBase.getName())
179+
.filter("metadata[\"rag_file_id\"] == \"" + ragFileId + "\"")
180+
.outputFields(Collections.singletonList("*"))
181+
.limit(Long.valueOf(pagingQuery.getSize()))
182+
.offset((long) (pagingQuery.getPage() - 1) * pagingQuery.getSize())
186183
.build());
187-
QueryResultsWrapper wrapper = new QueryResultsWrapper(results.getData());
188-
List<Map<String, Object>> list = wrapper.getRowRecords().stream().map(QueryResultsWrapper.RowRecord::getFieldValues).toList();
189-
190-
List<RagChunk> ragChunks = list.stream().map(item -> new RagChunk(
191-
item.get("id").toString(),
192-
item.get("text").toString(),
193-
item.get("metadata").toString()
194-
)).toList();
184+
List<QueryResp.QueryResult> queryResults = results.getQueryResults();
185+
List<RagChunk> ragChunks = queryResults.stream()
186+
.map(QueryResp.QueryResult::getEntity)
187+
.map(item -> new RagChunk(
188+
item.get("id").toString(),
189+
item.get("text").toString(),
190+
item.get("metadata").toString()
191+
)).toList();
195192

196193
// 获取总数
197-
R<QueryResults> countResults = milvusService.getMilvusClient().query(QueryParam.newBuilder()
198-
.withCollectionName(knowledgeBase.getName())
199-
.withExpr("metadata[\"rag_file_id\"] == \"" + ragFileId + "\"")
200-
.withOutFields(Collections.singletonList("count(*)"))
194+
QueryResp countResults = milvusService.getMilvusClient().query(QueryReq.builder()
195+
.collectionName(knowledgeBase.getName())
196+
.filter("metadata[\"rag_file_id\"] == \"" + ragFileId + "\"")
197+
.outputFields(Collections.singletonList("count(*)"))
201198
.build());
202-
QueryResultsWrapper countWrapper = new QueryResultsWrapper(countResults.getData());
203-
List<Map<String, Object>> countList = countWrapper.getRowRecords().stream().map(QueryResultsWrapper.RowRecord::getFieldValues).toList();
204-
long totalCount = Long.parseLong(countList.getFirst().get("count(*)").toString());
205199

200+
long totalCount = Long.parseLong(countResults.getQueryResults().getFirst().getEntity().get("count(*)").toString());
206201
return PagedResponse.of(ragChunks, pagingQuery.getPage(), totalCount, (int) Math.ceil((double) totalCount / pagingQuery.getSize()));
207202
}
208203
}

backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/infrastructure/event/RagEtlService.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,16 +119,19 @@ private void processRagFile(RagFile ragFile, DataInsertedEvent event) {
119119
EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(model);
120120
// 调用嵌入模型获取嵌入向量
121121

122+
if (!milvusService.hasCollection(event.knowledgeBase().getName())) {
123+
milvusService.createCollection(event.knowledgeBase().getName(), embeddingModel.dimension());
124+
}
125+
122126
Lists.partition(split, 20).forEach(partition -> {
123-
List<Embedding> content = embeddingModel.embedAll(partition).content();
124-
// 存储嵌入向量到 Milvus
125-
milvusService.embeddingStore(embeddingModel, event.knowledgeBase().getName()).addAll(content, partition);
127+
List<Embedding> embeddings = embeddingModel.embedAll(partition).content();
128+
milvusService.addAll(event.knowledgeBase().getName(),partition, embeddings);
126129
});
127130
}
128131

129132
/**
130133
* 根据文件类型返回对应的文档解析器
131-
*
134+
*x
132135
* @param fileType 文件类型
133136
* @return 文档解析器
134137
*/

backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/infrastructure/milvus/MilvusService.java

Lines changed: 131 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,28 @@
11
package com.datamate.rag.indexer.infrastructure.milvus;
22

3+
import com.google.gson.*;
4+
import dev.langchain4j.data.embedding.Embedding;
35
import dev.langchain4j.data.segment.TextSegment;
46
import dev.langchain4j.model.embedding.EmbeddingModel;
57
import dev.langchain4j.store.embedding.EmbeddingStore;
68
import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
7-
import io.milvus.client.MilvusClient;
8-
import io.milvus.client.MilvusServiceClient;
9-
import io.milvus.param.ConnectParam;
9+
import io.milvus.common.clientenum.FunctionType;
10+
import io.milvus.v2.client.ConnectConfig;
11+
import io.milvus.v2.client.MilvusClientV2;
12+
import io.milvus.v2.common.DataType;
13+
import io.milvus.v2.common.IndexParam;
14+
import io.milvus.v2.service.collection.request.AddFieldReq;
15+
import io.milvus.v2.service.collection.request.CreateCollectionReq;
16+
import io.milvus.v2.service.collection.request.HasCollectionReq;
17+
import io.milvus.v2.service.vector.request.InsertReq;
1018
import lombok.extern.slf4j.Slf4j;
1119
import org.springframework.beans.factory.annotation.Value;
1220
import org.springframework.stereotype.Component;
1321

22+
import java.util.*;
23+
24+
import static dev.langchain4j.internal.Utils.randomUUID;
25+
1426
/**
1527
* Milvus 服务类
1628
*
@@ -24,13 +36,19 @@ public class MilvusService {
2436
private String milvusHost;
2537
@Value("${datamate.rag.milvus-port:19530}")
2638
private int milvusPort;
39+
@Value("${datamate.rag.milvus-uri:http://milvus-standalone:19530}")
40+
private String milvusUri;
41+
private static final Gson GSON;
42+
43+
static {
44+
GSON = (new GsonBuilder()).setObjectToNumberStrategy(ToNumberPolicy.LONG_OR_DOUBLE).create();
45+
}
2746

28-
private volatile MilvusClient milvusClient;
47+
private volatile MilvusClientV2 milvusClient;
2948

3049
public EmbeddingStore<TextSegment> embeddingStore(EmbeddingModel embeddingModel, String knowledgeBaseName) {
3150
return MilvusEmbeddingStore.builder()
32-
.host(milvusHost)
33-
.port(milvusPort)
51+
.uri(milvusUri)
3452
.collectionName(knowledgeBaseName)
3553
.dimension(embeddingModel.dimension())
3654
.build();
@@ -41,16 +59,15 @@ public EmbeddingStore<TextSegment> embeddingStore(EmbeddingModel embeddingModel,
4159
*
4260
* @return MilvusClient
4361
*/
44-
public MilvusClient getMilvusClient() {
62+
public MilvusClientV2 getMilvusClient() {
4563
if (milvusClient == null) {
4664
synchronized (this) {
4765
if (milvusClient == null) {
4866
try {
49-
ConnectParam connectParam = ConnectParam.newBuilder()
50-
.withHost(milvusHost)
51-
.withPort(milvusPort)
67+
ConnectConfig connectConfig = ConnectConfig.builder()
68+
.uri(milvusUri)
5269
.build();
53-
milvusClient = new MilvusServiceClient(connectParam);
70+
milvusClient = new MilvusClientV2(connectConfig);
5471
log.info("Milvus client connected successfully");
5572
} catch (Exception e) {
5673
log.error("Milvus client connection failed: {}", e.getMessage());
@@ -61,4 +78,107 @@ public MilvusClient getMilvusClient() {
6178
}
6279
return milvusClient;
6380
}
81+
82+
83+
public boolean hasCollection(String collectionName) {
84+
HasCollectionReq request = HasCollectionReq.builder().collectionName(collectionName).build();
85+
return getMilvusClient().hasCollection(request);
86+
}
87+
88+
public void createCollection(String collectionName, int dimension) {
89+
CreateCollectionReq.CollectionSchema schema = CreateCollectionReq.CollectionSchema.builder()
90+
.build();
91+
schema.addField(AddFieldReq.builder()
92+
.fieldName("id")
93+
.dataType(DataType.VarChar)
94+
.maxLength(36)
95+
.isPrimaryKey(true)
96+
.autoID(false)
97+
.build());
98+
schema.addField(AddFieldReq.builder()
99+
.fieldName("text")
100+
.dataType(DataType.VarChar)
101+
.maxLength(65535)
102+
.enableAnalyzer(true)
103+
.build());
104+
schema.addField(AddFieldReq.builder()
105+
.fieldName("metadata")
106+
.dataType(DataType.JSON)
107+
.build());
108+
schema.addField(AddFieldReq.builder()
109+
.fieldName("vector")
110+
.dataType(DataType.FloatVector)
111+
.dimension(dimension)
112+
.build());
113+
schema.addField(AddFieldReq.builder()
114+
.fieldName("sparse")
115+
.dataType(DataType.SparseFloatVector)
116+
.build());
117+
schema.addFunction(CreateCollectionReq.Function.builder()
118+
.functionType(FunctionType.BM25)
119+
.name("text_bm25_emb")
120+
.inputFieldNames(Collections.singletonList("text"))
121+
.outputFieldNames(Collections.singletonList("sparse"))
122+
.build());
123+
124+
Map<String, Object> params = new HashMap<>();
125+
params.put("inverted_index_algo", "DAAT_MAXSCORE");
126+
params.put("bm25_k1", 1.2);
127+
params.put("bm25_b", 0.75);
128+
129+
List<IndexParam> indexes = new ArrayList<>();
130+
indexes.add(IndexParam.builder()
131+
.fieldName("sparse")
132+
.indexType(IndexParam.IndexType.SPARSE_INVERTED_INDEX)
133+
.metricType(IndexParam.MetricType.BM25)
134+
.extraParams(params)
135+
.build());
136+
indexes.add(IndexParam.builder()
137+
.fieldName("vector")
138+
.indexType(IndexParam.IndexType.FLAT)
139+
.metricType(IndexParam.MetricType.COSINE)
140+
.extraParams(Map.of())
141+
.build());
142+
143+
CreateCollectionReq createCollectionReq = CreateCollectionReq.builder()
144+
.collectionName(collectionName)
145+
.collectionSchema(schema)
146+
.indexParams(indexes)
147+
.build();
148+
this.getMilvusClient().createCollection(createCollectionReq);
149+
}
150+
151+
public void addAll(String collectionName, List<TextSegment> textSegments, List<Embedding> embeddings) {
152+
List<JsonObject> data = convertToJsonObjects(textSegments, embeddings);
153+
InsertReq insertReq = InsertReq.builder()
154+
.collectionName(collectionName)
155+
.data(data)
156+
.build();
157+
this.getMilvusClient().insert(insertReq);
158+
}
159+
160+
public List<JsonObject> convertToJsonObjects(List<TextSegment> textSegments, List<Embedding> embeddings) {
161+
List<JsonObject> data = new ArrayList<>();
162+
for (int i = 0; i < textSegments.size(); i++) {
163+
JsonObject jsonObject = new JsonObject();
164+
jsonObject.addProperty("id", randomUUID());
165+
jsonObject.addProperty("text", textSegments.get(i).text());
166+
jsonObject.add("metadata", GSON.toJsonTree(textSegments.get(i).metadata().toMap()).getAsJsonObject());
167+
JsonArray vectorArray = new JsonArray();
168+
for (float f : embeddings.get(i).vector()) {
169+
vectorArray.add(f);
170+
}
171+
jsonObject.add("vector", vectorArray);
172+
data.add(jsonObject);
173+
}
174+
return data;
175+
}
176+
177+
List<String> generateIds(int n) {
178+
List<String> ids = new ArrayList<>();
179+
for (int i = 0; i < n; i++) {
180+
ids.add(randomUUID());
181+
}
182+
return ids;
183+
}
64184
}

backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/KnowledgeBaseController.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,10 @@ public PagedResponse<RagChunk> getChunks(@PathVariable("knowledgeBaseId") String
133133
*
134134
* @param knowledgeBaseId 知识库 ID
135135
*/
136-
@PostMapping("/{knowledgeBaseId}/retrieve")
137-
public PagedResponse<RagChunk> retrieve(@PathVariable("knowledgeBaseId") String knowledgeBaseId,
138-
@RequestBody @Valid RetrieveReq request,
139-
PagingQuery pagingQuery) {
140-
return knowledgeBaseService.retrieve(knowledgeBaseId, request, pagingQuery);
141-
}
136+
// @PostMapping("/{knowledgeBaseId}/retrieve")
137+
// public PagedResponse<RagChunk> retrieve(@PathVariable("knowledgeBaseId") String knowledgeBaseId,
138+
// @RequestBody @Valid RetrieveReq request,
139+
// PagingQuery pagingQuery) {
140+
// return knowledgeBaseService.retrieve(knowledgeBaseId, request, pagingQuery);
141+
// }
142142
}

0 commit comments

Comments
 (0)