Skip to content

Commit 5748e1c

Browse files
authored
Fixes #4231: The apoc.vectordb.milvus.query* and apoc.vectordb.weaviate.query* procedures should get the fields config from metadataKey if present (#4241)
* Fixes #4231: The apoc.vectordb.milvus.query* and apoc.vectordb.weaviate.query* procedures should get the fields config from metadataKey if present * test fixes and changes review * fix tests
1 parent 00d75d3 commit 5748e1c

File tree

7 files changed

+170
-10
lines changed

7 files changed

+170
-10
lines changed

extended-it/src/test/java/apoc/vectordb/MilvusTest.java

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package apoc.vectordb;
22

33
import apoc.ml.Prompt;
4+
import apoc.util.ExtendedTestUtil;
45
import apoc.util.TestUtil;
56
import apoc.util.Util;
67
import org.junit.AfterClass;
@@ -42,11 +43,14 @@
4243
import static apoc.vectordb.VectorMappingConfig.METADATA_KEY;
4344
import static apoc.vectordb.VectorMappingConfig.MODE_KEY;
4445
import static apoc.vectordb.VectorMappingConfig.MappingMode;
46+
import static apoc.vectordb.VectorMappingConfig.NO_FIELDS_ERROR_MSG;
4547
import static apoc.vectordb.VectorMappingConfig.NODE_LABEL;
4648
import static apoc.vectordb.VectorMappingConfig.REL_TYPE;
49+
import static org.assertj.core.api.Assertions.assertThat;
4750
import static org.junit.Assert.assertEquals;
4851
import static org.junit.Assert.assertNotNull;
4952
import static org.junit.Assert.assertNull;
53+
import static org.junit.Assert.fail;
5054
import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME;
5155
import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME;
5256

@@ -488,4 +492,68 @@ WITH collect(node) as paths
488492
VectorDbTestUtil::assertRagWithVectors);
489493
}
490494

495+
@Test
496+
public void queryVectorsWithMetadataKeyNoFields() {
497+
Map<String, Object> conf = map(
498+
ALL_RESULTS_KEY, true,
499+
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
500+
REL_TYPE, "TEST",
501+
ENTITY_KEY, "readID",
502+
METADATA_KEY, "foo"
503+
)
504+
);
505+
testResult(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)",
506+
map("host", HOST, "conf", conf),
507+
VectorDbTestUtil::assertMetadataFooResult);
508+
}
509+
510+
@Test
511+
public void queryVectorsWithNoMetadataKeyNoFields() {
512+
Map<String, Object> params = map(
513+
"host", HOST, "conf", Map.of(
514+
ALL_RESULTS_KEY, true,
515+
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
516+
REL_TYPE, "TEST",
517+
ENTITY_KEY, "readID"
518+
))
519+
);
520+
String query = "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)";
521+
ExtendedTestUtil.assertFails(db, query, params, NO_FIELDS_ERROR_MSG);
522+
}
523+
524+
@Test
525+
public void queryAndUpdateMetadataKeyWithoutFieldsTest() {
526+
Map<String, Object> conf = map(
527+
ALL_RESULTS_KEY, true,
528+
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
529+
REL_TYPE, "TEST",
530+
ENTITY_KEY, "readID",
531+
METADATA_KEY, "foo"
532+
)
533+
);
534+
535+
String query = "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)";
536+
537+
testResult(db, query,
538+
map("host", HOST, "conf", conf),
539+
VectorDbTestUtil::assertMetadataFooResult);
540+
}
541+
542+
@Test
543+
public void queryAndUpdateWithNoMetadataKeyNoFields() {
544+
Map<String, Object> conf = map(
545+
ALL_RESULTS_KEY, true,
546+
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
547+
REL_TYPE, "TEST",
548+
ENTITY_KEY, "readID"
549+
)
550+
);
551+
db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");
552+
Map<String, Object> params = Util.map("host", HOST,
553+
"conf", conf);
554+
555+
String query = "CALL apoc.vectordb.milvus.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)";
556+
557+
ExtendedTestUtil.assertFails(db, query, params, NO_FIELDS_ERROR_MSG);
558+
}
491559
}

extended-it/src/test/java/apoc/vectordb/WeaviateTest.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package apoc.vectordb;
22

33
import apoc.ml.Prompt;
4+
import apoc.util.ExtendedTestUtil;
45
import apoc.util.MapUtil;
56
import apoc.util.TestUtil;
67
import org.junit.AfterClass;
@@ -606,4 +607,61 @@ private static void assertQueryVectorsWithSystemDbStorage(String keyConfig, Stri
606607

607608
assertNodesCreated(db);
608609
}
610+
611+
@Test
612+
public void queryVectorsWithMetadataKeyNoFields() {
613+
testResult(db, "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " +
614+
" YIELD score, vector, id, metadata RETURN * ORDER BY id",
615+
map("host", HOST, "conf", map(
616+
ALL_RESULTS_KEY, true,
617+
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
618+
NODE_LABEL, "Test",
619+
ENTITY_KEY, "myId",
620+
METADATA_KEY, "foo"
621+
),
622+
HEADERS_KEY, ADMIN_AUTHORIZATION)),
623+
VectorDbTestUtil::assertMetadataFooResult);
624+
}
625+
626+
@Test
627+
public void queryVectorsWithNoMetadataKeyNoFields() {
628+
Map<String, Object> params = map("host", HOST, "conf", map(
629+
ALL_RESULTS_KEY, true,
630+
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
631+
NODE_LABEL, "Test",
632+
ENTITY_KEY, "myId"
633+
),
634+
HEADERS_KEY, ADMIN_AUTHORIZATION));
635+
String query = "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " +
636+
" YIELD score, vector, id, metadata RETURN * ORDER BY id";
637+
ExtendedTestUtil.assertFails(db, query, params, NO_FIELDS_ERROR_MSG);
638+
}
639+
640+
@Test
641+
public void queryAndUpdateMetadataKeyWithoutFieldsTest() {
642+
db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");
643+
Map<String, Object> conf = map(ALL_RESULTS_KEY, true,
644+
HEADERS_KEY, ADMIN_AUTHORIZATION,
645+
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
646+
NODE_LABEL, "Test",
647+
ENTITY_KEY, "myId",
648+
METADATA_KEY, "foo"));
649+
testResult(db, "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " +
650+
" YIELD score, vector, id, metadata, node RETURN * ORDER BY id",
651+
map("host", HOST, "conf", conf),
652+
VectorDbTestUtil::assertMetadataFooResult);
653+
}
654+
655+
@Test
656+
public void queryAndUpdateWithCreateNodeUsingExistingNodeFailWithNoMetadataKeyAndNoFields() {
657+
db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");
658+
Map<String, Object> params = map("host", HOST,
659+
"conf", Map.of(ALL_RESULTS_KEY, true,
660+
HEADERS_KEY, ADMIN_AUTHORIZATION,
661+
MAPPING_KEY, map(EMBEDDING_KEY, "vect",
662+
NODE_LABEL, "Test",
663+
ENTITY_KEY, "myId")));
664+
String query = "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) YIELD score, vector, id, metadata, node RETURN * ORDER BY id";
665+
ExtendedTestUtil.assertFails(db, query, params, NO_FIELDS_ERROR_MSG);
666+
}
609667
}

extended/src/main/java/apoc/vectordb/MilvusHandler.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import java.util.Map;
88

99
import static apoc.util.MapUtil.map;
10-
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
10+
import static apoc.vectordb.VectorDbUtil.addMetadataKeyToFields;
1111
import static apoc.vectordb.VectorEmbeddingConfig.META_AS_SUBKEY_KEY;
1212
import static apoc.vectordb.VectorEmbeddingConfig.SCORE_KEY;
1313

@@ -57,10 +57,8 @@ public VectorEmbeddingConfig fromQuery(Map<String, Object> config, ProcedureCall
5757
private VectorEmbeddingConfig getVectorEmbeddingConfig(Map<String, Object> config, List<String> procFields, String collection, Map<String, Object> additionalBodies) {
5858
config.putIfAbsent(META_AS_SUBKEY_KEY, false);
5959

60-
List listFields = (List) config.get(FIELDS_KEY);
61-
if (listFields == null) {
62-
throw new RuntimeException("You have to define `field` list of parameter to be returned");
63-
}
60+
List listFields = addMetadataKeyToFields(config);
61+
6462
if (procFields.contains("vector") && !listFields.contains("vector")) {
6563
listFields.add("vector");
6664
}

extended/src/main/java/apoc/vectordb/VectorDbUtil.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33

44
import apoc.ExtendedSystemPropertyKeys;
55
import apoc.SystemPropertyKeys;
6+
import apoc.util.CollectionUtils;
67
import apoc.util.ExtendedMapUtils;
78
import apoc.util.Util;
9+
import org.apache.commons.lang3.StringUtils;
810
import org.neo4j.graphdb.Label;
911
import org.neo4j.graphdb.Node;
1012
import org.neo4j.graphdb.Relationship;
1113

1214
import java.net.HttpURLConnection;
1315
import java.net.URL;
16+
import java.util.ArrayList;
1417
import java.util.HashMap;
1518
import java.util.List;
1619
import java.util.Map;
@@ -20,9 +23,12 @@
2023
import static apoc.ml.RestAPIConfig.ENDPOINT_KEY;
2124
import static apoc.ml.RestAPIConfig.METHOD_KEY;
2225
import static apoc.util.SystemDbUtil.withSystemDb;
26+
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
2327
import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY;
28+
import static apoc.vectordb.VectorEmbeddingConfig.METADATA_KEY;
2429
import static apoc.vectordb.VectorMappingConfig.MODE_KEY;
2530
import static apoc.vectordb.VectorMappingConfig.MappingMode.READ_ONLY;
31+
import static apoc.vectordb.VectorMappingConfig.NO_FIELDS_ERROR_MSG;
2632

2733
public class VectorDbUtil {
2834

@@ -136,6 +142,26 @@ public static void methodAndPayloadNull(Map<String, Object> config) {
136142
config.put(BODY_KEY, null);
137143
}
138144

145+
public static List addMetadataKeyToFields(Map<String, Object> config) {
146+
List listFields = (List) config.getOrDefault(FIELDS_KEY, new ArrayList<>());
147+
148+
Map<String, Object> mapping = (Map<String, Object>) config.get(MAPPING_KEY);
149+
150+
String metadataKey = mapping == null
151+
? null
152+
: (String) mapping.get(METADATA_KEY);
153+
154+
if (CollectionUtils.isEmpty(listFields)) {
155+
156+
if (StringUtils.isEmpty(metadataKey)) {
157+
throw new RuntimeException(NO_FIELDS_ERROR_MSG);
158+
}
159+
listFields.add(metadataKey);
160+
}
161+
162+
return listFields;
163+
}
164+
139165
/**
140166
* If the vectorDb is WEAVIATE and endpoint doesn't end with `/vN`, where N is a number,
141167
* then add `/v1` to the endpoint

extended/src/main/java/apoc/vectordb/VectorMappingConfig.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ enum MappingMode {
1515
public static final String EMBEDDING_KEY = "embeddingKey";
1616
public static final String SIMILARITY_KEY = "similarity";
1717
public static final String MODE_KEY = "mode";
18+
public static final String NO_FIELDS_ERROR_MSG = "You need to define either the 'field' list parameter, or the 'metadataKey' string parameter within the `embeddingConfig` parameter";
1819

1920
private final String metadataKey;
2021
private final String entityKey;

extended/src/main/java/apoc/vectordb/WeaviateHandler.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import static apoc.ml.RestAPIConfig.BODY_KEY;
1010
import static apoc.ml.RestAPIConfig.METHOD_KEY;
1111
import static apoc.util.MapUtil.map;
12-
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
12+
import static apoc.vectordb.VectorDbUtil.addMetadataKeyToFields;
1313
import static apoc.vectordb.VectorEmbeddingConfig.METADATA_KEY;
1414
import static apoc.vectordb.VectorEmbeddingConfig.VECTOR_KEY;
1515

@@ -47,10 +47,8 @@ public VectorEmbeddingConfig fromQuery(Map<String, Object> config, ProcedureCall
4747
config.putIfAbsent(METHOD_KEY, "POST");
4848
VectorEmbeddingConfig vectorEmbeddingConfig = getVectorEmbeddingConfig(config);
4949

50-
List list = (List) config.get(FIELDS_KEY);
51-
if (list == null) {
52-
throw new RuntimeException("You have to define `field` list of parameter to be returned");
53-
}
50+
List list = addMetadataKeyToFields(config);
51+
5452
Object fieldList = String.join("\n", list);
5553

5654
filter = filter == null

extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import java.util.Map;
1111

12+
import static apoc.vectordb.VectorEmbeddingConfig.DEFAULT_METADATA;
1213
import static apoc.util.TestUtil.testResult;
1314
import static apoc.util.Util.map;
1415
import static org.junit.Assert.assertEquals;
@@ -114,4 +115,14 @@ public static String ragSetup(GraphDatabaseService db) {
114115
db.executeTransactionally("CREATE (:Rag {readID: 'one'}), (:Rag {readID: 'two'})");
115116
return openAIKey;
116117
}
118+
119+
public static void assertMetadataFooResult(Result r) {
120+
Map<String, Object> row = r.next();
121+
Map<String, Object> metadata = (Map<String, Object>) row.get(DEFAULT_METADATA);
122+
assertEquals("one", metadata.get("foo"));
123+
row = r.next();
124+
metadata = (Map<String, Object>) row.get(DEFAULT_METADATA);
125+
assertEquals("two", metadata.get("foo"));
126+
assertFalse(r.hasNext());
127+
}
117128
}

0 commit comments

Comments
 (0)