Skip to content

Commit 891ef2a

Browse files
Til7701sobychacko
authored andcommitted
spring-projectsGH-3896: Similarity searches with the MariaDBVectorStore do not provide a score
Fixes spring-projects#3896 Auto-cherry-pick to 1.0.x Fixes spring-projectsGH-3896 (spring-projects#3896) In addition to provide the distance as part of the metadata, MariaDBVectorStore calculates the score from the distance and sets via the builder. Signed-off-by: Tilman Holube <[email protected]> Restore distance as part of metadata Signed-off-by: Soby Chacko <[email protected]>
1 parent 54a2dc3 commit 891ef2a

File tree

2 files changed

+25
-21
lines changed

2 files changed

+25
-21
lines changed

vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStore.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,13 @@ public Document mapRow(ResultSet rs, int rowNum) throws SQLException {
487487

488488
metadata.put("distance", distance);
489489

490-
return new Document(id, content, metadata);
490+
// @formatter:off
491+
return Document.builder()
492+
.id(id)
493+
.text(content)
494+
.metadata(metadata)
495+
.score(1.0 - distance)
496+
.build(); // @formatter:on
491497
}
492498

493499
private Map<String, Object> toMap(String source) {

vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreIT.java

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -122,20 +122,20 @@ static Stream<Arguments> provideFilters() {
122122
);
123123
}
124124

125-
private static boolean isSortedByDistance(List<Document> docs) {
125+
private static boolean isSortedByScore(List<Document> docs) {
126126

127-
List<Float> distances = docs.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
127+
List<Double> scores = docs.stream().map(Document::getScore).toList();
128128

129-
if (CollectionUtils.isEmpty(distances) || distances.size() == 1) {
129+
if (CollectionUtils.isEmpty(scores) || scores.size() == 1) {
130130
return true;
131131
}
132132

133-
Iterator<Float> iter = distances.iterator();
134-
Float current;
135-
Float previous = iter.next();
133+
Iterator<Double> iter = scores.iterator();
134+
Double current;
135+
Double previous = iter.next();
136136
while (iter.hasNext()) {
137137
current = iter.next();
138-
if (previous > current) {
138+
if (previous < current) {
139139
return false;
140140
}
141141
previous = current;
@@ -166,7 +166,8 @@ public void addAndSearch(String distanceType) {
166166
assertThat(results).hasSize(1);
167167
Document resultDoc = results.get(0);
168168
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
169-
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
169+
assertThat(resultDoc.getMetadata()).containsKeys("meta2");
170+
assertThat(resultDoc.getScore()).isBetween(0.0, 1.0);
170171

171172
// Remove all documents from the store
172173
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
@@ -315,7 +316,8 @@ public void documentUpdate(String distanceType) {
315316
Document resultDoc = results.get(0);
316317
assertThat(resultDoc.getId()).isEqualTo(document.getId());
317318
assertThat(resultDoc.getText()).isEqualTo("Spring AI rocks!!");
318-
assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance");
319+
assertThat(resultDoc.getMetadata()).containsKeys("meta1");
320+
assertThat(resultDoc.getScore()).isBetween(0.0, 1.0);
319321

320322
Document sameIdDocument = new Document(document.getId(),
321323
"The World is Big and Salvation Lurks Around the Corner",
@@ -329,7 +331,8 @@ public void documentUpdate(String distanceType) {
329331
resultDoc = results.get(0);
330332
assertThat(resultDoc.getId()).isEqualTo(document.getId());
331333
assertThat(resultDoc.getText()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
332-
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
334+
assertThat(resultDoc.getMetadata()).containsKeys("meta2");
335+
assertThat(resultDoc.getScore()).isBetween(0.0, 1.0);
333336

334337
dropTable(context);
335338
});
@@ -350,19 +353,14 @@ public void searchWithThreshold(String distanceType) {
350353

351354
assertThat(fullResult).hasSize(3);
352355

353-
assertThat(isSortedByDistance(fullResult)).isTrue();
356+
assertThat(isSortedByScore(fullResult)).isTrue();
354357

355-
List<Float> distances = fullResult.stream()
356-
.map(doc -> (Float) doc.getMetadata().get("distance"))
357-
.toList();
358+
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
358359

359-
float threshold = (distances.get(0) + distances.get(1)) / 2;
360+
double threshold = (scores.get(0) + scores.get(1)) / 2;
360361

361-
List<Document> results = vectorStore.similaritySearch(SearchRequest.builder()
362-
.query("Time Shelter")
363-
.topK(5)
364-
.similarityThreshold(1 - threshold)
365-
.build());
362+
List<Document> results = vectorStore.similaritySearch(
363+
SearchRequest.builder().query("Time Shelter").topK(5).similarityThreshold(threshold).build());
366364

367365
assertThat(results).hasSize(1);
368366
Document resultDoc = results.get(0);

0 commit comments

Comments
 (0)