Skip to content

Commit e70fcd6

Browse files
authored
Fixes #4233: Improve Weaviate error handling (#4239)
* Fixes #4233: Improve Weaviate error handling * fix tests
1 parent b4ebf3f commit e70fcd6

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

extended-it/src/test/java/apoc/vectordb/WeaviateTest.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,19 @@ public void queryReadOnlyVectorsWithMapping() {
455455
);
456456
}
457457

458+
@Test
459+
public void queryWithWrongEmbeddingSize() {
460+
Map<String, Object> conf = map(ALL_RESULTS_KEY, true,
461+
FIELDS_KEY, FIELDS,
462+
HEADERS_KEY, READONLY_AUTHORIZATION);
463+
464+
String expectedErrMsg = "distance between entrypoint and query node: vector lengths don't match: 4 vs 3";
465+
466+
assertFails(db, "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9], null, 5, $conf)",
467+
map("host", HOST, "conf", conf),
468+
expectedErrMsg);
469+
}
470+
458471
@Test
459472
public void queryVectorsWithCreateRelWithoutVectorResult() {
460473

extended/src/main/java/apoc/vectordb/Weaviate.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import apoc.ml.RestAPIConfig;
55
import apoc.result.ListResult;
66
import apoc.result.MapResult;
7+
import apoc.util.CollectionUtils;
78
import apoc.util.UrlResolver;
9+
import org.apache.commons.lang3.StringUtils;
810
import org.neo4j.graphdb.GraphDatabaseService;
911
import org.neo4j.graphdb.Transaction;
1012
import org.neo4j.graphdb.security.URLAccessChecker;
@@ -240,7 +242,13 @@ private Stream<EmbeddingResult> queryCommon(String hostOrKey, String collection,
240242

241243
return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx,
242244
v -> {
243-
Object getValue = ((Map<String, Map>) v).get("data").get("Get");
245+
Map<String, Map> mapResult = (Map<String, Map>) v;
246+
List<Map> errors = (List<Map>) mapResult.get("errors");
247+
if ( CollectionUtils.isNotEmpty(errors) ) {
248+
String message = "An error occurred during Weaviate API response: \n" + StringUtils.join(errors, "\n");
249+
throw new RuntimeException(message);
250+
}
251+
Object getValue = mapResult.get("data").get("Get");
244252
Object collectionValue = ((Map) getValue).get(collection);
245253
return ((List<Map>) collectionValue).stream()
246254
.map(i -> {

0 commit comments

Comments
 (0)