Skip to content

Commit 3a564cd

Browse files
committed
align changes
1 parent 3afec61 commit 3a564cd

File tree

9 files changed

+266
-119
lines changed

9 files changed

+266
-119
lines changed

openmetadata-service/src/main/java/org/openmetadata/service/search/vector/OpenSearchVectorService.java

Lines changed: 58 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import com.fasterxml.jackson.databind.JsonNode;
44
import com.fasterxml.jackson.databind.ObjectMapper;
55
import jakarta.json.stream.JsonParser;
6+
import java.io.IOException;
67
import java.io.InputStream;
78
import java.io.StringReader;
89
import java.nio.charset.StandardCharsets;
@@ -86,70 +87,78 @@ public void close() {
8687
}
8788

8889
@Override
90+
@SuppressWarnings("unchecked")
8991
public VectorSearchResponse search(
9092
String query, Map<String, List<String>> filters, int size, int k, double threshold) {
9193
long start = System.currentTimeMillis();
9294
try {
9395
float[] queryVector = embeddingClient.embed(query);
9496
int overFetchSize = size * OVER_FETCH_MULTIPLIER;
95-
int overFetchK = k * OVER_FETCH_MULTIPLIER;
9697

97-
String queryJson =
98-
VectorSearchQueryBuilder.build(queryVector, overFetchSize, overFetchK, filters);
98+
String queryJson = VectorSearchQueryBuilder.build(queryVector, overFetchSize, k, filters);
9999
String indexName = getClusteredIndexName();
100100
String responseBody = executeGenericRequest("POST", "/" + indexName + "/_search", queryJson);
101101

102102
JsonNode root = MAPPER.readTree(responseBody);
103103
JsonNode hitsNode = root.path("hits").path("hits");
104104

105-
LinkedHashMap<String, Map<String, Object>> parentGroups = new LinkedHashMap<>();
105+
LinkedHashMap<String, List<Map<String, Object>>> byParent = new LinkedHashMap<>();
106106
for (JsonNode hit : hitsNode) {
107107
double score = hit.path("_score").asDouble(0.0);
108-
if (threshold > 0 && score < threshold) {
108+
if (score < threshold) {
109109
continue;
110110
}
111111

112-
JsonNode source = hit.path("_source");
113-
String parentId = source.path("parent_id").asText(hit.path("_id").asText());
112+
Map<String, Object> hitMap = MAPPER.convertValue(hit.path("_source"), Map.class);
113+
hitMap.put("_score", score);
114114

115-
if (!parentGroups.containsKey(parentId)) {
116-
Map<String, Object> hitMap = MAPPER.convertValue(source, Map.class);
117-
hitMap.put("_id", hit.path("_id").asText());
118-
hitMap.put("_score", score);
119-
parentGroups.put(parentId, hitMap);
115+
String parentId = (String) hitMap.get("parent_id");
116+
if (parentId != null) {
117+
byParent.computeIfAbsent(parentId, kVal -> new ArrayList<>()).add(hitMap);
120118
}
119+
}
121120

122-
if (parentGroups.size() >= size) {
121+
List<Map<String, Object>> results = new ArrayList<>();
122+
int parentCount = 0;
123+
for (List<Map<String, Object>> chunks : byParent.values()) {
124+
if (parentCount >= size) {
123125
break;
124126
}
127+
results.addAll(chunks);
128+
parentCount++;
125129
}
126130

127131
long tookMillis = System.currentTimeMillis() - start;
128-
return new VectorSearchResponse(tookMillis, new ArrayList<>(parentGroups.values()));
132+
return new VectorSearchResponse(tookMillis, results);
129133
} catch (Exception e) {
130134
LOG.error("Vector search failed: {}", e.getMessage(), e);
131135
long tookMillis = System.currentTimeMillis() - start;
132136
return new VectorSearchResponse(tookMillis, Collections.emptyList());
133137
}
134138
}
135139

136-
@SuppressWarnings("unchecked")
137140
String executeGenericRequest(String method, String endpoint, String body) {
138141
try {
139142
OpenSearchGenericClient genericClient = client.generic();
140143
var request = Requests.builder().endpoint(endpoint).method(method).json(body).build();
141-
var response = genericClient.execute(request);
142-
return response
143-
.getBody()
144-
.map(
145-
b -> {
146-
try {
147-
return new String(b.bodyAsBytes(), StandardCharsets.UTF_8);
148-
} catch (Exception e) {
149-
return "{}";
150-
}
151-
})
152-
.orElse("{}");
144+
try (var response = genericClient.execute(request)) {
145+
if (response.getStatus() >= 400) {
146+
String errorBody = response.getBody().map(b -> b.bodyAsString()).orElse("no body");
147+
throw new IOException(
148+
"OpenSearch request failed with status " + response.getStatus() + ": " + errorBody);
149+
}
150+
return response
151+
.getBody()
152+
.map(
153+
b -> {
154+
try {
155+
return new String(b.bodyAsBytes(), StandardCharsets.UTF_8);
156+
} catch (Exception e) {
157+
return "{}";
158+
}
159+
})
160+
.orElse("{}");
161+
}
153162
} catch (Exception e) {
154163
LOG.error("Generic request failed: {} {}", method, endpoint, e);
155164
throw new RuntimeException("OpenSearch generic request failed", e);
@@ -186,11 +195,22 @@ public void updateVectorEmbeddingsWithMigration(
186195
try {
187196
String parentId = entity.getId().toString();
188197
String currentFingerprint = VectorDocBuilder.computeFingerprintForEntity(entity);
189-
String existingFingerprint = getExistingFingerprint(sourceIndex, parentId);
190198

191-
if (currentFingerprint.equals(existingFingerprint)) {
192-
copyExistingVectorDocuments(sourceIndex, targetIndex, parentId, currentFingerprint);
193-
return;
199+
if (sourceIndex != null) {
200+
try {
201+
String existingFingerprint = getExistingFingerprint(sourceIndex, parentId);
202+
if (currentFingerprint.equals(existingFingerprint)) {
203+
if (copyExistingVectorDocuments(
204+
sourceIndex, targetIndex, parentId, currentFingerprint)) {
205+
return;
206+
}
207+
}
208+
} catch (Exception ex) {
209+
LOG.warn(
210+
"Migration copy failed for entity {}, falling back to recomputation: {}",
211+
parentId,
212+
ex.getMessage());
213+
}
194214
}
195215

196216
List<Map<String, Object>> docs = VectorDocBuilder.fromEntity(entity, embeddingClient);
@@ -274,7 +294,8 @@ public Map<String, String> getExistingFingerprintsBatch(
274294
}
275295

276296
@Override
277-
public void copyExistingVectorDocuments(
297+
@SuppressWarnings("unchecked")
298+
public boolean copyExistingVectorDocuments(
278299
String sourceIndex, String targetIndex, String parentId, String fingerprint) {
279300
try {
280301
String searchQuery =
@@ -286,15 +307,17 @@ public void copyExistingVectorDocuments(
286307
JsonNode hits = root.path("hits").path("hits");
287308

288309
if (!hits.isArray() || hits.isEmpty()) {
289-
return;
310+
return false;
290311
}
291312

292313
List<Map<String, Object>> docs = new ArrayList<>();
293314
for (JsonNode hit : hits) {
294315
Map<String, Object> source = MAPPER.convertValue(hit.path("_source"), Map.class);
316+
source.put("fingerprint", fingerprint);
295317
docs.add(source);
296318
}
297319
bulkIndex(docs, targetIndex);
320+
return true;
298321
} catch (Exception e) {
299322
LOG.error(
300323
"Failed to copy vector documents from {} to {} for parent_id={}: {}",
@@ -303,6 +326,7 @@ public void copyExistingVectorDocuments(
303326
parentId,
304327
e.getMessage(),
305328
e);
329+
return false;
306330
}
307331
}
308332

@@ -440,7 +464,7 @@ public void bulkIndex(List<Map<String, Object>> documents, String targetIndex) {
440464
Map<String, Object> doc = documents.get(i);
441465
String parentId = (String) doc.get("parent_id");
442466
int chunkIndex = doc.containsKey("chunk_index") ? (int) doc.get("chunk_index") : i;
443-
String docId = parentId + "_" + chunkIndex;
467+
String docId = parentId + "-" + chunkIndex;
444468

445469
operations.add(
446470
BulkOperation.of(

0 commit comments

Comments (0)