Skip to content

Commit f2a1732

Browse files
committed
feat: qdrant 정확성 수정
1 parent c87bf1c commit f2a1732

File tree

2 files changed

+90
-15
lines changed

2 files changed

+90
-15
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package com.ai.lawyer.global.qdrant.initializer;
2+
3+
import io.qdrant.client.QdrantClient;
4+
import jakarta.annotation.PostConstruct;
5+
import lombok.RequiredArgsConstructor;
6+
import lombok.extern.slf4j.Slf4j;
7+
import org.springframework.beans.factory.annotation.Value;
8+
import org.springframework.stereotype.Component;
9+
10+
import java.util.concurrent.ExecutionException;
11+
12+
@Slf4j
13+
@Component
14+
@RequiredArgsConstructor
15+
public class QdrantInitializer {
16+
17+
private final QdrantClient qdrantClient;
18+
19+
@Value("${spring.ai.vectorstore.qdrant.collection-name}")
20+
private String collectionName;
21+
22+
@Value("${spring.ai.vectorstore.qdrant.vector-size}")
23+
private Long vectorSize;
24+
25+
@PostConstruct
26+
private void existQdrantCollection() throws InterruptedException, ExecutionException {
27+
var collections = qdrantClient.listCollectionsAsync().get();
28+
boolean collectionExists = collections.stream()
29+
.anyMatch(collection -> collection.equals(collectionName));
30+
31+
if (!collectionExists) {
32+
log.info("'{}' 컬렉션이 존재하지 않아 새로 생성 중", collectionName);
33+
qdrantClient.createCollectionAsync(
34+
collectionName,
35+
io.qdrant.client.grpc.Collections.VectorParams.newBuilder()
36+
.setSize(vectorSize.intValue())
37+
.setDistance(io.qdrant.client.grpc.Collections.Distance.Cosine)
38+
.build()
39+
).get();
40+
log.info("'{}' 컬렉션 생성 완료", collectionName);
41+
} else {
42+
log.info("'{}' 컬렉션이 이미 존재합니다.", collectionName);
43+
}
44+
}
45+
46+
}

backend/src/main/java/com/ai/lawyer/global/qdrant/service/QdrantService.java

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,35 +7,64 @@
77
import org.springframework.ai.vectorstore.filter.Filter;
88
import org.springframework.stereotype.Service;
99

10+
import java.util.ArrayList;
1011
import java.util.Collections;
12+
import java.util.Comparator;
1113
import java.util.List;
14+
import java.util.stream.Collectors;
1215

1316
@Service
1417
@RequiredArgsConstructor
1518
public class QdrantService {
1619

1720
private final VectorStore vectorStore;
1821

19-
public List<Document> searchDocument(String query, String key, String value, int topK) {
22+
public List<Document> searchDocument(String query, String key, String value) {
2023

21-
SearchRequest caseSearchRequest = SearchRequest.builder()
22-
.query(query)
23-
.topK(topK)
24+
SearchRequest findCaseNumberRequest = SearchRequest.builder()
25+
.query(query).topK(1)
2426
.filterExpression(new Filter.Expression(Filter.ExpressionType.EQ, new Filter.Key(key), new Filter.Value(value)))
2527
.build();
26-
List<Document> similarCaseDocuments = vectorStore.similaritySearch(caseSearchRequest);
27-
28-
if (caseSearchRequest == null) {
29-
return Collections.singletonList(
30-
Document.builder()
31-
.text("더미")
32-
.metadata(key, value)
33-
.score(0.0)
34-
.build()
35-
);
28+
List<Document> mostSimilarDocuments = vectorStore.similaritySearch(findCaseNumberRequest);
29+
30+
31+
if (mostSimilarDocuments.isEmpty()) {
32+
return Collections.emptyList();
33+
}
34+
String targetCaseNumber = (String) mostSimilarDocuments.get(0).getMetadata().get("caseNumber");
35+
if (targetCaseNumber == null) {
36+
return mostSimilarDocuments;
37+
}
38+
39+
SearchRequest fetchAllChunksRequest = SearchRequest.builder()
40+
.query(query).topK(100)
41+
.filterExpression(new Filter.Expression(Filter.ExpressionType.EQ, new Filter.Key("caseNumber"), new Filter.Value(targetCaseNumber)))
42+
.build();
43+
List<Document> allChunksOfCase = new ArrayList<>(vectorStore.similaritySearch(fetchAllChunksRequest));
44+
45+
if (allChunksOfCase.isEmpty()) {
46+
return Collections.emptyList();
3647
}
3748

38-
return similarCaseDocuments;
49+
allChunksOfCase.sort(Comparator.comparingInt(doc ->
50+
((Number) doc.getMetadata().get("chunkIndex")).intValue()
51+
));
52+
53+
String mergedContent = allChunksOfCase.stream()
54+
.map(Document::getText)
55+
.collect(Collectors.joining(""));
56+
57+
Document bestScoringDoc = allChunksOfCase.stream()
58+
.max(Comparator.comparing(Document::getScore))
59+
.orElse(allChunksOfCase.get(0));
60+
61+
Document finalDocument = Document.builder()
62+
.text(mergedContent)
63+
.metadata(bestScoringDoc.getMetadata())
64+
.score(bestScoringDoc.getScore())
65+
.build();
66+
67+
return Collections.singletonList(finalDocument);
3968
}
4069

4170
}

0 commit comments

Comments
 (0)