Skip to content

Commit b7697f4

Browse files
authored
Fixes #4232: The apoc.vectordb.configure(WEAVIATE', ..) procedure should append /v1 to url (#4248)
1 parent 8223e27 commit b7697f4

File tree

3 files changed

+83
-36
lines changed

3 files changed

+83
-36
lines changed

extended-it/src/test/java/apoc/vectordb/WeaviateTest.java

Lines changed: 65 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import apoc.util.MapUtil;
55
import apoc.util.TestUtil;
66
import org.junit.AfterClass;
7-
import org.junit.Assume;
87
import org.junit.Before;
98
import org.junit.BeforeClass;
109
import org.junit.ClassRule;
@@ -507,41 +506,21 @@ MAPPING_KEY, map(REL_TYPE, "TEST",
507506
public void queryVectorsWithSystemDbStorage() {
508507
String keyConfig = "weaviate-config-foo";
509508
String baseUrl = "http://" + HOST + "/v1";
510-
Map<String, String> mapping = map(EMBEDDING_KEY, "vect",
511-
NODE_LABEL, "Test",
512-
ENTITY_KEY, "myId",
513-
METADATA_KEY, "foo");
514-
sysDb.executeTransactionally("CALL apoc.vectordb.configure($vectorName, $keyConfig, $databaseName, $conf)",
515-
map("vectorName", WEAVIATE.toString(),
516-
"keyConfig", keyConfig,
517-
"databaseName", DEFAULT_DATABASE_NAME,
518-
"conf", map(
519-
"host", baseUrl,
520-
"credentials", ADMIN_KEY,
521-
"mapping", mapping
522-
)
523-
)
524-
);
525-
526-
db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");
527-
528-
testResult(db, "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)",
529-
map("host", keyConfig,
530-
"conf", map(FIELDS_KEY, FIELDS, ALL_RESULTS_KEY, true)
531-
),
532-
r -> {
533-
Map<String, Object> row = r.next();
534-
assertBerlinResult(row, ID_1, NODE);
535-
assertNotNull(row.get("score"));
536-
assertNotNull(row.get("vector"));
509+
assertQueryVectorsWithSystemDbStorage(keyConfig, baseUrl, false);
510+
}
537511

538-
row = r.next();
539-
assertLondonResult(row, ID_2, NODE);
540-
assertNotNull(row.get("score"));
541-
assertNotNull(row.get("vector"));
542-
});
512+
@Test
513+
public void queryVectorsWithSystemDbStorageWithUrlWithoutVersion() {
514+
String keyConfig = "weaviate-config-foo";
515+
String baseUrl = "http://" + HOST;
516+
assertQueryVectorsWithSystemDbStorage(keyConfig, baseUrl, false);
517+
}
543518

544-
assertNodesCreated(db);
519+
@Test
520+
public void queryVectorsWithSystemDbStorageWithUrlV3Version() {
521+
String keyConfig = "weaviate-config-foo";
522+
String baseUrl = "http://" + HOST + "/v3";
523+
assertQueryVectorsWithSystemDbStorage(keyConfig, baseUrl, true);
545524
}
546525

547526
@Test
@@ -575,4 +554,56 @@ WITH collect(node) as paths
575554
),
576555
VectorDbTestUtil::assertRagWithVectors);
577556
}
557+
558+
private static void assertQueryVectorsWithSystemDbStorage(String keyConfig, String baseUrl, boolean fails) {
559+
Map<String, String> mapping = map(EMBEDDING_KEY, "vect",
560+
NODE_LABEL, "Test",
561+
ENTITY_KEY, "myId",
562+
METADATA_KEY, "foo");
563+
sysDb.executeTransactionally("CALL apoc.vectordb.configure($vectorName, $keyConfig, $databaseName, $conf)",
564+
map("vectorName", WEAVIATE.toString(),
565+
"keyConfig", keyConfig,
566+
"databaseName", DEFAULT_DATABASE_NAME,
567+
"conf", map(
568+
"host", baseUrl,
569+
"credentials", ADMIN_KEY,
570+
"mapping", mapping
571+
)
572+
)
573+
);
574+
575+
db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");
576+
577+
String query = "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)";
578+
Map<String, Object> params = map("host", keyConfig,
579+
"conf", map(FIELDS_KEY, FIELDS, ALL_RESULTS_KEY, true)
580+
);
581+
582+
if (fails) {
583+
assertFails(
584+
db,
585+
query,
586+
params,
587+
"Caused by: java.io.FileNotFoundException: http://127.0.0.1:" + HOST.split(":")[1] + "/v3/graphql"
588+
);
589+
return;
590+
}
591+
592+
593+
testResult(db, query,
594+
params,
595+
r -> {
596+
Map<String, Object> row = r.next();
597+
assertBerlinResult(row, ID_1, NODE);
598+
assertNotNull(row.get("score"));
599+
assertNotNull(row.get("vector"));
600+
601+
row = r.next();
602+
assertLondonResult(row, ID_2, NODE);
603+
assertNotNull(row.get("score"));
604+
assertNotNull(row.get("vector"));
605+
});
606+
607+
assertNodesCreated(db);
608+
}
578609
}

extended/src/main/java/apoc/vectordb/VectorDb.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@
3838
import static apoc.util.ExtendedUtil.setProperties;
3939
import static apoc.util.JsonUtil.OBJECT_MAPPER;
4040
import static apoc.util.SystemDbUtil.withSystemDb;
41-
import static apoc.vectordb.VectorDbUtil.*;
41+
import static apoc.vectordb.VectorDbUtil.EmbeddingResult;
42+
import static apoc.vectordb.VectorDbUtil.appendVersionUrlIfNeeded;
43+
import static apoc.vectordb.VectorDbUtil.getEndpoint;
4244
import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY;
4345
import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY;
4446

@@ -259,7 +261,7 @@ public void vectordb(
259261
Node node = Util.mergeNode(transaction, label, null, Pair.of(SystemPropertyKeys.name.name(), configKey));
260262

261263
Map mapping = (Map) config.get("mapping");
262-
String host = (String) config.get("host");
264+
String host = appendVersionUrlIfNeeded(type, (String) config.get("host"));
263265
Object credentials = config.get("credentials");
264266

265267
if (host != null) {

extended/src/main/java/apoc/vectordb/VectorDbUtil.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,4 +135,18 @@ public static void methodAndPayloadNull(Map<String, Object> config) {
135135
config.put(METHOD_KEY, null);
136136
config.put(BODY_KEY, null);
137137
}
138+
139+
/**
140+
* If the vectorDb is WEAVIATE and endpoint doesn't end with `/vN`, where N is a number,
141+
* then add `/v1` to the endpoint
142+
*/
143+
public static String appendVersionUrlIfNeeded(VectorDbHandler.Type type, String host) {
144+
if (VectorDbHandler.Type.WEAVIATE == type) {
145+
String regex = ".*(/v\\d+)$";
146+
if (!host.matches(regex)) {
147+
host = host + "/v1";
148+
}
149+
}
150+
return host;
151+
}
138152
}

0 commit comments

Comments
 (0)