Skip to content

Commit bc26cfb

Browse files
authored
feat: Refactor knowledge base retrieval to return detailed search results and enhance API integration #108
1 parent b50c12d commit bc26cfb

File tree

9 files changed

+169
-76
lines changed

9 files changed

+169
-76
lines changed

backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/application/KnowledgeBaseService.java

Lines changed: 11 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -76,16 +76,14 @@ public String create(KnowledgeBaseCreateReq request) {
7676
public void update(String knowledgeBaseId, KnowledgeBaseUpdateReq request) {
7777
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
7878
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
79-
if (StringUtils.hasText(request.getName())) {
79+
if (StringUtils.hasText(request.getName()) && !knowledgeBase.getName().equals(request.getName())) {
8080
milvusService.getMilvusClient().renameCollection(RenameCollectionReq.builder()
8181
.collectionName(knowledgeBase.getName())
8282
.newCollectionName(request.getName())
8383
.build());
8484
knowledgeBase.setName(request.getName());
8585
}
86-
if (StringUtils.hasText(request.getDescription())) {
87-
knowledgeBase.setDescription(request.getDescription());
88-
}
86+
knowledgeBase.setDescription(request.getDescription());
8987
knowledgeBaseRepository.updateById(knowledgeBase);
9088
}
9189

@@ -147,7 +145,7 @@ public void addFiles(AddFilesReq request) {
147145
RagFile ragFile = new RagFile();
148146
ragFile.setKnowledgeBaseId(knowledgeBase.getId());
149147
ragFile.setFileId(fileInfo.id());
150-
ragFile.setFileName(fileInfo.name());
148+
ragFile.setFileName(fileInfo.fileName());
151149
ragFile.setStatus(FileStatus.UNPROCESSED);
152150
return ragFile;
153151
}).toList();
@@ -209,23 +207,19 @@ public PagedResponse<RagChunk> getChunks(String knowledgeBaseId, String ragFileI
209207
* @param request 检索请求
210208
* @return 检索结果
211209
*/
212-
public SearchResp retrieve(RetrieveReq request) {
210+
public List<SearchResp.SearchResult> retrieve(RetrieveReq request) {
213211
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseIds().getFirst()))
214212
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
215213
ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
216214
EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
217215
Embedding embedding = embeddingModel.embed(request.getQuery()).content();
218216
SearchResp searchResp = milvusService.hybridSearch(knowledgeBase.getName(), request.getQuery(), embedding.vector(), request.getTopK());
219-
return searchResp;
220-
221-
// request.getKnowledgeBaseIds().forEach(knowledgeId -> {
222-
// KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeId))
223-
// .orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
224-
// ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
225-
// EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
226-
// Embedding embedding = embeddingModel.embed(request.getQuery()).content();
227-
// searchResp = milvusService.hybridSearch(knowledgeBase.getName(), request.getQuery(), embedding.vector(), request.getTopK());
228-
// });
229-
// return searchResp;
217+
List<SearchResp.SearchResult> searchResults = searchResp.getSearchResults().getFirst();
218+
219+
searchResults.forEach(item -> {
220+
String metadata = item.getEntity().get("metadata").toString();
221+
item.getEntity().put("metadata", metadata);
222+
});
223+
return searchResults;
230224
}
231225
}

backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/infrastructure/milvus/MilvusService.java

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
import io.milvus.v2.service.collection.request.CreateCollectionReq;
1616
import io.milvus.v2.service.collection.request.HasCollectionReq;
1717
import io.milvus.v2.service.vector.request.AnnSearchReq;
18+
import io.milvus.v2.service.vector.request.FunctionScore;
1819
import io.milvus.v2.service.vector.request.HybridSearchReq;
1920
import io.milvus.v2.service.vector.request.InsertReq;
2021
import io.milvus.v2.service.vector.request.data.BaseVector;
@@ -197,21 +198,26 @@ public SearchResp hybridSearch(String collectionName, String query, float[] quer
197198
.params("{\"drop_ratio_search\": 0.2}")
198199
.topK(topK)
199200
.build());
201+
200202
CreateCollectionReq.Function ranker = CreateCollectionReq.Function.builder()
201-
.name("rrf")
203+
.name("weight")
202204
.functionType(FunctionType.RERANK)
203-
.param("reranker", "rrf")
204-
.param("k", "60")
205+
.param("reranker", "weighted")
206+
.param("weights", "[0.1, 0.9]")
207+
.param("norm_score", "true")
205208
.build();
206209

210+
FunctionScore functionScore = FunctionScore.builder()
211+
.functions(Collections.singletonList(ranker))
212+
.build();
207213

208214

209215
SearchResp searchResp = this.getMilvusClient().hybridSearch(HybridSearchReq.builder()
210216
.collectionName(collectionName)
211217
.searchRequests(searchRequests)
212-
.ranker(ranker)
218+
.functionScore(functionScore)
213219
.outFields(Arrays.asList("id", "text", "metadata"))
214-
.topK(topK)
220+
.limit(topK)
215221
.build());
216222
return searchResp;
217223
}

backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/KnowledgeBaseController.java

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,8 @@
1111
import lombok.RequiredArgsConstructor;
1212
import org.springframework.web.bind.annotation.*;
1313

14+
import java.util.List;
15+
1416

1517
/**
1618
* 知识库控制器
@@ -136,7 +138,7 @@ public PagedResponse<RagChunk> getChunks(@PathVariable("knowledgeBaseId") String
136138
* @return 检索结果
137139
*/
138140
@PostMapping("/retrieve")
139-
public SearchResp retrieve(@RequestBody @Valid RetrieveReq request) {
141+
public List<SearchResp.SearchResult> retrieve(@RequestBody @Valid RetrieveReq request) {
140142
return knowledgeBaseService.retrieve(request);
141143
}
142144
}

backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/dto/AddFilesReq.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,6 @@ public class AddFilesReq {
2121
private String delimiter;
2222
private List<FileInfo> files;
2323

24-
public record FileInfo(String id, String name) {
24+
public record FileInfo(String id, String fileName) {
2525
}
2626
}

backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/dto/KnowledgeBaseCreateReq.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,12 +20,12 @@ public class KnowledgeBaseCreateReq {
2020
*/
2121
@NotEmpty(message = "知识库名称不能为空")
2222
@Size(min = 1, max = 255, message = "知识库名称长度必须在 1 到 255 之间")
23-
@Pattern(regexp = "^[a-zA-Z0-9_]+$", message = "知识库名称只能包含字母、数字和下划线")
23+
@Pattern(regexp = "^[a-zA-Z][a-zA-Z0-9_]*$", message = "知识库名称只能包含字母、数字和下划线")
2424
private String name;
2525
/**
2626
* 知识库描述
2727
*/
28-
@Size(min = 1, max = 512, message = "知识库描述长度必须在 1 到 512 之间")
28+
@Size(max = 512, message = "知识库描述长度必须在 0 到 512 之间")
2929
private String description;
3030

3131
/**

backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/dto/KnowledgeBaseUpdateReq.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,11 +20,11 @@ public class KnowledgeBaseUpdateReq {
2020
*/
2121
@NotEmpty(message = "知识库名称不能为空")
2222
@Size(min = 1, max = 255, message = "知识库名称长度必须在 1 到 255 之间")
23-
@Pattern(regexp = "^[a-zA-Z0-9_]+$", message = "知识库名称只能包含字母、数字和下划线")
23+
@Pattern(regexp = "^[a-zA-Z][a-zA-Z0-9_]*$", message = "知识库名称只能包含字母、数字和下划线")
2424
private String name;
2525
/**
2626
* 知识库描述
2727
*/
28-
@Size(min = 1, max = 512, message = "知识库描述长度必须在 1 到 512 之间")
28+
@Size(max = 512, message = "知识库描述长度必须在 0 到 512 之间")
2929
private String description;
3030
}

0 commit comments

Comments
 (0)