values = new ArrayList<>(3);
+ values.addAll(base.toLog());
+ values.add("nprobe=" + nProbe);
+ return values;
}
}
diff --git a/src/main/java/ru/rt/restream/reindexer/vector/params/KnnParams.java b/src/main/java/ru/rt/restream/reindexer/vector/params/KnnParams.java
index 1be4d26..d399962 100644
--- a/src/main/java/ru/rt/restream/reindexer/vector/params/KnnParams.java
+++ b/src/main/java/ru/rt/restream/reindexer/vector/params/KnnParams.java
@@ -15,34 +15,66 @@
*/
package ru.rt.restream.reindexer.vector.params;
+import lombok.NonNull;
+
/**
* Factories for KnnSearchParams.
*/
public class KnnParams {
+ @Deprecated // k(int) builds the same k-only param
public static BaseKnnSearchParam base(int k) {
checkK(k);
- return new BaseKnnSearchParam(k);
+ return new BaseKnnSearchParam(k, null); // null in the radius position: no radius is set
}
- public static IndexHnswSearchParam hnsw(int k, int ef) {
+ public static BaseKnnSearchParam base(int k, float radius) { // factory for a param carrying both k and radius
+ checkK(k); // only k is validated; radius is passed through unchecked
+ return new BaseKnnSearchParam(k, radius);
+ }
+
+ public static BaseKnnSearchParam k(int k) { // factory for a k-only param
checkK(k);
+ return new BaseKnnSearchParam(k, null); // null in the radius position: no radius is set
+ }
+
+ public static BaseKnnSearchParam radius(float radius) { // factory for a radius-only param; the value is not validated here
+ return new BaseKnnSearchParam(null, radius); // null in the k position: no k is set
+ }
+
+ public static IndexHnswSearchParam hnsw(int k, int ef) { // HNSW param built from a plain k and ef
if (ef < k) {
throw new IllegalArgumentException("Minimal value of 'ef' must be greater than or equal to 'k'");
}
- return new IndexHnswSearchParam(k, ef);
+ return new IndexHnswSearchParam(k(k), ef); // k(k) runs checkK, so k itself is validated after the ef check above
+ }
+
+ public static IndexHnswSearchParam hnsw(@NonNull BaseKnnSearchParam base, int ef) { // HNSW param wrapping an existing base param
+ if (base.getK() != null && ef < base.getK()) { // ef is compared with k only when the base param carries a k
+ throw new IllegalArgumentException("Minimal value of 'ef' must be greater than or equal to 'k'");
+ }
+ return new IndexHnswSearchParam(base, ef);
}
public static IndexBfSearchParam bf(int k) {
- checkK(k);
- return new IndexBfSearchParam(k);
+ return new IndexBfSearchParam(k(k)); // k(k) performs the checkK validation that used to be inlined here
+ }
+
+ public static IndexBfSearchParam bf(@NonNull BaseKnnSearchParam base) { // wraps an existing base param (k, radius, or both)
+ return new IndexBfSearchParam(base); // no extra validation beyond the @NonNull annotation on base
}
public static IndexIvfSearchParam ivf(int k, int nProbe) {
- checkK(k);
if (nProbe <= 0) {
throw new IllegalArgumentException("Minimal value of 'nProbe' must be greater than 0");
}
- return new IndexIvfSearchParam(k, nProbe);
+ return new IndexIvfSearchParam(k(k), nProbe); // k is now validated inside k(k), i.e. after the nProbe check
+ }
+
+ public static IndexIvfSearchParam ivf(@NonNull BaseKnnSearchParam base, int nProbe) { // IVF param wrapping an existing base param
+ if (nProbe <= 0) { // nProbe must be strictly positive
+ throw new IllegalArgumentException("Minimal value of 'nProbe' must be greater than 0");
+ }
+ return new IndexIvfSearchParam(base, nProbe);
}
private static void checkK(int k) {
diff --git a/src/main/java/ru/rt/restream/reindexer/vector/params/KnnSearchParam.java b/src/main/java/ru/rt/restream/reindexer/vector/params/KnnSearchParam.java
index 63fc0d8..9c2558c 100644
--- a/src/main/java/ru/rt/restream/reindexer/vector/params/KnnSearchParam.java
+++ b/src/main/java/ru/rt/restream/reindexer/vector/params/KnnSearchParam.java
@@ -23,13 +23,6 @@
* Common interface for KNN search parameters.
*/
public interface KnnSearchParam {
- /**
- * K - the maximum number of documents returned from the index for subsequent filtering.
- *
- * Only required parameter for all vector index types.
- */
- int getK();
-
/**
* Utility method for serializing KNN parameters to CJSON avoiding switch.
*/
diff --git a/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorBfTest.java b/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorBfTest.java
index 26ee895..cd6631a 100644
--- a/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorBfTest.java
+++ b/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorBfTest.java
@@ -50,10 +50,15 @@ public abstract class FloatVectorBfTest extends DbBaseTest {
private final String namespaceName = "items";
private Namespace vectorNs;
+ private List<VectorItem> testItems; // fixture items shared by every test; typed so the for-each below compiles
@BeforeEach
public void setUp() {
vectorNs = db.openNamespace(namespaceName, NamespaceOptions.defaultOptions(), VectorItem.class);
+ testItems = getTestVectorItems(); // load the fixture once per test instead of in each test method
+ for (VectorItem item : testItems) {
+ db.insert(namespaceName, item);
+ }
}
@Test
@@ -91,16 +96,11 @@ public void testInsertWithWrongVectorSize_isException() {
}
@Test
- public void testSearchWithBaseParams_isOk() {
- List testItems = getTestVectorItems();
- for (VectorItem item : testItems) {
- db.insert(namespaceName, item);
- }
-
+ public void testSearchWithBaseParamK_isOk() {
List list = db.query(namespaceName, VectorItem.class)
.selectAllFields()
.whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f},
- KnnParams.base(2))
+ KnnParams.k(2))
.toList();
assertThat(list.size(), is(2));
@@ -111,12 +111,51 @@ public void testSearchWithBaseParams_isOk() {
}
@Test
- public void testSearchWithVecBfParams_isOk() {
- List testItems = getTestVectorItems();
- for (VectorItem item : testItems) {
- db.insert(namespaceName, item);
- }
+ public void testSearchWithBaseParamRadius_isOk() {
+ List<VectorItem> list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f},
+ KnnParams.radius(0.7f)) // radius-only search: no k limit is set
+ .toList();
+
+ assertThat(list.size(), is(4)); // four fixture items are expected for radius 0.7
+ assertThat(list.get(0).getId(), is(18));
+ assertThat(list.get(1).getId(), is(6));
+ assertThat(list.get(2).getId(), is(7));
+ assertThat(list.get(3).getId(), is(8));
+ }
+ @Test
+ public void testSearchWithBaseParamsKAndRadius_isOk() {
+ List<VectorItem> list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f},
+ KnnParams.base(3, 0.7f)) // k is the tighter limit here
+ .toList();
+
+ // by k (3 records) + by radius (4 records) = 3 records
+ assertThat(list.size(), is(3));
+ assertThat(list.get(0).getId(), is(18));
+ assertThat(list.get(1).getId(), is(6));
+ assertThat(list.get(2).getId(), is(7));
+
+ list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f},
+ KnnParams.base(5, 0.7f)) // radius is the tighter limit here
+ .toList();
+
+ // by k (5 records) + by radius (4 records) = 4 records
+ assertThat(list.size(), is(4));
+ assertThat(list.get(0).getId(), is(18));
+ assertThat(list.get(1).getId(), is(6));
+ assertThat(list.get(2).getId(), is(7));
+ assertThat(list.get(3).getId(), is(8));
+ }
+
+ @Test
+ public void testSearchWithVecBfParams_isOk() {
+ // only k - 3 records
List list = db.query(namespaceName, VectorItem.class)
.selectAllFields()
.whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f},
@@ -130,15 +169,42 @@ public void testSearchWithVecBfParams_isOk() {
assertThat(list.get(1).getVector(), is(testItems.get(18).getVector()));
assertThat(list.get(2).getId(), is(19));
assertThat(list.get(2).getVector(), is(testItems.get(19).getVector()));
+
+ // only radius 0.7 - 5 records
+ list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f},
+ KnnParams.bf(KnnParams.radius(0.7f)))
+ .toList();
+
+ assertThat(list.size(), is(5));
+ assertThat(list.get(0).getId(), is(8));
+ assertThat(list.get(1).getId(), is(18));
+ assertThat(list.get(2).getId(), is(19));
+ assertThat(list.get(3).getId(), is(1));
+ assertThat(list.get(4).getId(), is(2));
+
+ // by k (3 records) + by radius (5 records) = 3 records
+ list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f},
+ KnnParams.bf(KnnParams.base(3, 0.7f)))
+ .toList();
+
+ assertThat(list.size(), is(3));
+
+ // by k (6 records) + by radius (5 records) = 5 records
+ list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f},
+ KnnParams.bf(KnnParams.base(6, 0.7f)))
+ .toList();
+
+ assertThat(list.size(), is(5));
}
@Test
public void testSearchWithIncorrectVecBfParams_isException() {
- List testItems = getTestVectorItems();
- for (VectorItem item : testItems) {
- db.insert(namespaceName, item);
- }
-
assertThrows(IllegalArgumentException.class,
() -> db.query(namespaceName, VectorItem.class)
.selectAllFields()
@@ -149,12 +215,17 @@ public void testSearchWithIncorrectVecBfParams_isException() {
}
@Test
- public void testSearchWithNotVecBfNorBaseParams_isException() {
- List testItems = getTestVectorItems();
- for (VectorItem item : testItems) {
- db.insert(namespaceName, item);
- }
+ public void testTrySearchWithVecBfParamFromNullBaseParam_isException() {
+ assertThrows(NullPointerException.class, // bf declares its base parameter @NonNull
+ () -> db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.5f, 0.0f, 0.6f},
+ KnnParams.bf(null)) // presumably rejected by the @NonNull check before whereKnn is invoked — confirm
+ .toList());
+ }
+ @Test
+ public void testSearchWithNotVecBfNorBaseParams_isException() {
assertThrows(RuntimeException.class,
() -> db.query(namespaceName, VectorItem.class)
.selectAllFields()
diff --git a/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorHnswTest.java b/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorHnswTest.java
index d4e76de..486371c 100644
--- a/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorHnswTest.java
+++ b/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorHnswTest.java
@@ -50,10 +50,15 @@ public abstract class FloatVectorHnswTest extends DbBaseTest {
private final String namespaceName = "items";
private Namespace vectorNs;
+ private List<VectorItem> testItems; // fixture items shared by every test; typed so the for-each below compiles
@BeforeEach
public void setUp() {
vectorNs = db.openNamespace(namespaceName, NamespaceOptions.defaultOptions(), VectorItem.class);
+ testItems = getTestVectorItems(); // load the fixture once per test instead of in each test method
+ for (VectorItem item : testItems) {
+ db.insert(namespaceName, item);
+ }
}
@Test
@@ -91,16 +96,11 @@ public void testInsertWithWrongVectorSize_isException() {
}
@Test
- public void testSearchWithBaseParams_isOk() {
- List testItems = getTestVectorItems();
- for (VectorItem item : testItems) {
- db.insert(namespaceName, item);
- }
-
+ public void testSearchWithBaseParamK_isOk() {
List list = db.query(namespaceName, VectorItem.class)
.selectAllFields()
.whereKnn("vector", new float[]{0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f},
- KnnParams.base(2))
+ KnnParams.k(2))
.toList();
assertThat(list.size(), is(2));
@@ -111,16 +111,55 @@ public void testSearchWithBaseParams_isOk() {
}
@Test
- public void testSearchWithHnswParams_isOk() {
- List testItems = getTestVectorItems();
- for (VectorItem item : testItems) {
- db.insert(namespaceName, item);
- }
+ public void testSearchWithBaseParamRadius_isOk() {
+ List<VectorItem> list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f},
+ KnnParams.radius(0.4f)) // radius-only search: no k limit is set
+ .toList();
+
+ assertThat(list.size(), is(4)); // four fixture items are expected for radius 0.4
+ assertThat(list.get(0).getId(), is(1));
+ assertThat(list.get(1).getId(), is(2));
+ assertThat(list.get(2).getId(), is(0));
+ assertThat(list.get(3).getId(), is(3));
+ }
+ @Test
+ public void testSearchWithBaseParamsKAndRadius_isOk() {
+ // by k (3 records) + by radius (4 records) = 3 records
+ List<VectorItem> list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f},
+ KnnParams.base(3, 0.4f)) // k is the tighter limit here
+ .toList();
+
+ assertThat(list.size(), is(3));
+ assertThat(list.get(0).getId(), is(1));
+ assertThat(list.get(1).getId(), is(2));
+ assertThat(list.get(2).getId(), is(0));
+
+ // by k (5 records) + by radius (4 records) = 4 records
+ list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f},
+ KnnParams.base(5, 0.4f)) // radius is the tighter limit here
+ .toList();
+
+ assertThat(list.size(), is(4));
+ assertThat(list.get(0).getId(), is(1));
+ assertThat(list.get(1).getId(), is(2));
+ assertThat(list.get(2).getId(), is(0));
+ assertThat(list.get(3).getId(), is(3));
+ }
+
+ @Test
+ public void testSearchWithHnswParams_isOk() {
+ // k - 2 records
List list = db.query(namespaceName, VectorItem.class)
.selectAllFields()
.whereKnn("vector", new float[]{0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f},
- KnnParams.hnsw(2, 2))
+ KnnParams.hnsw(2, 5))
.toList();
assertThat(list.size(), is(2));
@@ -128,15 +167,47 @@ public void testSearchWithHnswParams_isOk() {
assertThat(list.get(0).getVector(), is(testItems.get(2).getVector()));
assertThat(list.get(1).getId(), is(3));
assertThat(list.get(1).getVector(), is(testItems.get(3).getVector()));
+
+ // radius 0.4 - 4 records
+ list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f},
+ KnnParams.hnsw(KnnParams.radius(0.4f), 5))
+ .toList();
+
+ assertThat(list.size(), is(4));
+
+ // by k (3 records) + by radius (4 records) = 3 records
+ list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f},
+ KnnParams.hnsw(KnnParams.base(3, 0.4f), 5))
+ .toList();
+
+ assertThat(list.size(), is(3));
+
+ // by k (5 records) + by radius (4 records) = 4 records
+ list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f},
+ KnnParams.hnsw(KnnParams.base(5, 0.4f), 5))
+ .toList();
+
+ assertThat(list.size(), is(4));
}
@Test
- public void testSearchWithIncorrectHnswParams_isException() {
- List testItems = getTestVectorItems();
- for (VectorItem item : testItems) {
- db.insert(namespaceName, item);
- }
+ public void testTrySearchWithHnswParamFromNullBaseParam_isException() {
+ assertThrows(NullPointerException.class, // hnsw(base, ef) declares its base parameter @NonNull
+ () -> db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.5f, 0.0f, 0.6f}, // NOTE(review): 3 elements, while the other queries in this class use 8 — harmless only if hnsw(null, 5) throws first; confirm
+ KnnParams.hnsw(null, 5)) // null can only bind to the BaseKnnSearchParam overload, not hnsw(int, int)
+ .toList());
+ }
+ @Test
+ public void testSearchWithIncorrectHnswParams_isException() {
assertThrows(IllegalArgumentException.class,
() -> db.query(namespaceName, VectorItem.class)
.selectAllFields()
@@ -156,11 +227,6 @@ public void testSearchWithIncorrectHnswParams_isException() {
@Test
public void testSearchWithNotHnswNorBaseParams_isException() {
- List testItems = getTestVectorItems();
- for (VectorItem item : testItems) {
- db.insert(namespaceName, item);
- }
-
assertThrows(RuntimeException.class,
() -> db.query(namespaceName, VectorItem.class)
.selectAllFields()
diff --git a/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorIvfTest.java b/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorIvfTest.java
index b6e6cfa..9d36acb 100644
--- a/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorIvfTest.java
+++ b/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorIvfTest.java
@@ -41,6 +41,7 @@
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static ru.rt.restream.reindexer.Query.Condition.EQ;
+import static ru.rt.restream.reindexer.vector.params.KnnParams.radius;
import static ru.rt.restream.util.ReindexerUtils.getIndexByName;
/**
@@ -50,10 +51,15 @@ public abstract class FloatVectorIvfTest extends DbBaseTest {
private final String namespaceName = "items";
private Namespace vectorNs;
+ private List<VectorItem> testItems; // fixture items shared by every test; typed so the for-each below compiles
@BeforeEach
public void setUp() {
vectorNs = db.openNamespace(namespaceName, NamespaceOptions.defaultOptions(), VectorItem.class);
+ testItems = getTestVectorItems(); // load the fixture once per test instead of in each test method
+ for (VectorItem item : testItems) {
+ db.insert(namespaceName, item);
+ }
}
@Test
@@ -91,16 +97,11 @@ public void testInsertWithWrongVectorSize_isException() {
}
@Test
- public void testSearchWithBaseParams_isOk() {
- List testItems = getTestVectorItems();
- for (VectorItem item : testItems) {
- db.insert(namespaceName, item);
- }
-
+ public void testSearchWithBaseParamK_isOk() {
List list = db.query(namespaceName, VectorItem.class)
.selectAllFields()
.whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f},
- KnnParams.base(2))
+ KnnParams.k(2))
.toList();
assertThat(list.size(), is(2));
@@ -111,34 +112,106 @@ public void testSearchWithBaseParams_isOk() {
}
@Test
- public void testSearchWithIvfParams_isOk() {
- List testItems = getTestVectorItems();
- for (VectorItem item : testItems) {
- db.insert(namespaceName, item);
- }
+ public void testSearchWithBaseParamRadius_isOk() {
+ List<VectorItem> list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f},
+ radius(0.1f)) // statically imported KnnParams.radius: radius-only search, no k limit
+ .toList();
+ assertThat(list.size(), is(4)); // four fixture items are expected for radius 0.1
+ assertThat(list.get(0).getId(), is(18));
+ assertThat(list.get(1).getId(), is(6));
+ assertThat(list.get(2).getId(), is(7));
+ assertThat(list.get(3).getId(), is(8));
+ }
+
+ @Test
+ public void testSearchWithBaseParamsKAndRadius_isOk() {
+ // by k (3 records) + by radius (4 records) = 3 records
List<VectorItem> list = db.query(namespaceName, VectorItem.class)
.selectAllFields()
- .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f},
- KnnParams.ivf(3, 3))
+ .whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f},
+ KnnParams.base(3, 0.1f))
.toList();
assertThat(list.size(), is(3));
+ assertThat(list.get(0).getId(), is(18));
+ assertThat(list.get(1).getId(), is(7)); // NOTE(review): id 6 is skipped here although it ranks second in the k=5 query below — confirm this is intended
+ assertThat(list.get(2).getId(), is(8));
+
+ // by k (5 records) + by radius (4 records) = 4 records
+ list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f},
+ KnnParams.base(5, 0.1f))
+ .toList();
+
+ assertThat(list.size(), is(4));
+ assertThat(list.get(0).getId(), is(18));
+ assertThat(list.get(1).getId(), is(6));
+ assertThat(list.get(2).getId(), is(7));
+ assertThat(list.get(3).getId(), is(8));
+ }
+
+ @Test
+ public void testSearchWithIvfParams_isOk() {
+ // only k - 2 records
+ List<VectorItem> list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f},
+ KnnParams.ivf(2, 3)) // k = 2, nProbe = 3
+ .toList();
+
+ assertThat(list.size(), is(2));
assertThat(list.get(0).getId(), is(8));
assertThat(list.get(0).getVector(), is(testItems.get(8).getVector()));
assertThat(list.get(1).getId(), is(18));
assertThat(list.get(1).getVector(), is(testItems.get(18).getVector()));
+
+ // only radius - 3 records
+ list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f},
+ KnnParams.ivf(KnnParams.radius(0.3f), 3))
+ .toList();
+
+ assertThat(list.size(), is(3));
+ assertThat(list.get(0).getId(), is(8));
+ assertThat(list.get(1).getId(), is(18));
assertThat(list.get(2).getId(), is(19));
- assertThat(list.get(2).getVector(), is(testItems.get(19).getVector()));
+
+ // by k (2 records) + by radius (3 records) = 2 records
+ list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f},
+ KnnParams.ivf(KnnParams.base(2, 0.3f), 3)) // k is the tighter limit here
+ .toList();
+
+ assertThat(list.size(), is(2));
+
+ // by k (4 records) + by radius (3 records) = 3 records
+ list = db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f},
+ KnnParams.ivf(KnnParams.base(4, 0.3f), 3)) // radius is the tighter limit here
+ .toList();
+
+ assertThat(list.size(), is(3));
}
@Test
- public void testSearchWithIncorrectIvfParams_isException() {
- List testItems = getTestVectorItems();
- for (VectorItem item : testItems) {
- db.insert(namespaceName, item);
- }
+ public void testTrySearchWithIvfParamFromNullBaseParam_isException() {
+ assertThrows(NullPointerException.class, // ivf(base, nProbe) declares its base parameter @NonNull
+ () -> db.query(namespaceName, VectorItem.class)
+ .selectAllFields()
+ .whereKnn("vector", new float[]{0.5f, 0.0f, 0.6f},
+ KnnParams.ivf(null, 3)) // exercise the IVF factory (was bf(null)); nProbe 3 is valid, so only the null base can fail
+ .toList());
+ }
+ @Test
+ public void testSearchWithIncorrectIvfParams_isException() {
assertThrows(IllegalArgumentException.class,
() -> db.query(namespaceName, VectorItem.class)
.selectAllFields()
@@ -158,11 +231,6 @@ public void testSearchWithIncorrectIvfParams_isException() {
@Test
public void testSearchWithNotIvfNorBaseParams_isException() {
- List testItems = getTestVectorItems();
- for (VectorItem item : testItems) {
- db.insert(namespaceName, item);
- }
-
assertThrows(RuntimeException.class,
() -> db.query(namespaceName, VectorItem.class)
.selectAllFields()