diff --git a/src/main/java/ru/rt/restream/reindexer/binding/Consts.java b/src/main/java/ru/rt/restream/reindexer/binding/Consts.java index 4d0a68c..d4ae0e9 100644 --- a/src/main/java/ru/rt/restream/reindexer/binding/Consts.java +++ b/src/main/java/ru/rt/restream/reindexer/binding/Consts.java @@ -20,7 +20,7 @@ */ public final class Consts { - public static final String REINDEXER_VERSION = "v5.0.0"; + public static final String REINDEXER_VERSION = "v5.4.0"; public static final String DEF_APP_NAME = "java-connector"; public static final String APP_PROPERTY_NAME = "app.name"; @@ -84,7 +84,7 @@ public final class Consts { public static final int KNN_QUERY_TYPE_HNSW = 2; public static final int KNN_QUERY_TYPE_IVF = 3; - public static final int KNN_QUERY_PARAMS_VERSION = 0; + public static final int KNN_QUERY_PARAMS_VERSION = 1; public static final int RESULTS_FORMAT_MASK = 0xF; public static final int RESULTS_PURE = 0x0; diff --git a/src/main/java/ru/rt/restream/reindexer/binding/cproto/ByteBuffer.java b/src/main/java/ru/rt/restream/reindexer/binding/cproto/ByteBuffer.java index 4cbe5af..95743a2 100644 --- a/src/main/java/ru/rt/restream/reindexer/binding/cproto/ByteBuffer.java +++ b/src/main/java/ru/rt/restream/reindexer/binding/cproto/ByteBuffer.java @@ -97,6 +97,22 @@ public ByteBuffer(float expandFactor, int initialCapacity) { buffer = new byte[initialCapacity]; } + + /** + * Encodes an integer value into unsigned 8-bit integer. + * Increments buffer position. + * + * @param value value to encode + * @return the {@link ByteBuffer} for further customizations + */ + public ByteBuffer putUInt8(int value) { + if (value < 0 || value > 0xFF) { + throw new IllegalArgumentException(); + } + putIntBits(value, Byte.BYTES, -1); + return this; + } + /** * Encodes an integer value into unsigned 16-bit integer. * Increments buffer position. diff --git a/src/main/java/ru/rt/restream/reindexer/vector/params/BaseKnnSearchParam.java b/src/main/java/ru/rt/restream/reindexer/vector/params/BaseKnnSearchParam.java index e321cad..a644532 100644 --- a/src/main/java/ru/rt/restream/reindexer/vector/params/BaseKnnSearchParam.java +++ b/src/main/java/ru/rt/restream/reindexer/vector/params/BaseKnnSearchParam.java @@ -20,19 +20,36 @@ import lombok.Getter; import ru.rt.restream.reindexer.binding.cproto.ByteBuffer; -import java.util.Collections; +import java.util.ArrayList; import java.util.List; import static ru.rt.restream.reindexer.binding.Consts.KNN_QUERY_PARAMS_VERSION; import static ru.rt.restream.reindexer.binding.Consts.KNN_QUERY_TYPE_BASE; +/** + * Common parameters for all types of KNN indices. + * + *

If all parameters are specified, the filtering will be performed in such a way that all conditions are met. + * At least one of these parameters must be specified. + */ @Getter @AllArgsConstructor(access = AccessLevel.PACKAGE) public class BaseKnnSearchParam implements KnnSearchParam { + private static final int KNN_SERIALIZE_WITH_K = 1; + private static final int KNN_SERIALIZE_WITH_RADIUS = 1 << 1; + /** * The maximum number of documents returned from the index for subsequent filtering. */ - private final int k; + private final Integer k; + /** + * Parameter for filtering vectors by ranks. + * + *

Rank() < radius for L2 metrics and rank() > radius for cosine and inner product metrics. + * About default values and usage see + * + */ + private final Float radius; /** * {@inheritDoc} @@ -40,8 +57,35 @@ public class BaseKnnSearchParam implements KnnSearchParam { @Override public void serializeBy(ByteBuffer buffer) { buffer.putVarUInt32(KNN_QUERY_TYPE_BASE) - .putVarUInt32(KNN_QUERY_PARAMS_VERSION) - .putVarUInt32(k); + .putVarUInt32(KNN_QUERY_PARAMS_VERSION); + serializeKAndRadius(buffer); + } + + void serializeKAndRadius(ByteBuffer buffer) { + checkValues(); + int mask = 0; + if (k != null) { + mask |= KNN_SERIALIZE_WITH_K; + } + if (radius != null) { + mask |= KNN_SERIALIZE_WITH_RADIUS; + } + buffer.putUInt8(mask); + if (k != null) { + buffer.putVarUInt32(k); + } + if (radius != null) { + buffer.putFloat(radius); + } + } + + private void checkValues() { + if (k == null && radius == null) { + throw new IllegalArgumentException("Both params (k and radius) cannot be null"); + } + if (k != null && k <= 0) { + throw new IllegalArgumentException("'k' must be greater than 0"); + } } /** @@ -49,6 +93,13 @@ public void serializeBy(ByteBuffer buffer) { */ @Override public List toLog() { - return Collections.singletonList("k=" + k); + List values = new ArrayList<>(2); + if (k != null) { + values.add("k=" + k); + } + if (radius != null) { + values.add("radius=" + radius); + } + return values; } } diff --git a/src/main/java/ru/rt/restream/reindexer/vector/params/IndexBfSearchParam.java b/src/main/java/ru/rt/restream/reindexer/vector/params/IndexBfSearchParam.java index 917bee8..a146666 100644 --- a/src/main/java/ru/rt/restream/reindexer/vector/params/IndexBfSearchParam.java +++ b/src/main/java/ru/rt/restream/reindexer/vector/params/IndexBfSearchParam.java @@ -18,10 +18,10 @@ import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.Getter; +import lombok.NonNull; import ru.rt.restream.reindexer.binding.cproto.ByteBuffer; -import java.util.Arrays; -import java.util.Collections; +import java.util.ArrayList; import java.util.List; import static ru.rt.restream.reindexer.binding.Consts.KNN_QUERY_PARAMS_VERSION; @@ -31,9 +31,10 @@ @AllArgsConstructor(access = AccessLevel.PACKAGE) public class IndexBfSearchParam implements KnnSearchParam { /** - * The maximum number of documents returned from the index for subsequent filtering. + * Common parameters for KNN search. */ - private final int k; + @NonNull + private final BaseKnnSearchParam base; /** * {@inheritDoc} @@ -41,8 +42,8 @@ public class IndexBfSearchParam implements KnnSearchParam { @Override public void serializeBy(ByteBuffer buffer) { buffer.putVarUInt32(KNN_QUERY_TYPE_BRUTE_FORCE) - .putVarUInt32(KNN_QUERY_PARAMS_VERSION) - .putVarUInt32(k); + .putVarUInt32(KNN_QUERY_PARAMS_VERSION); + base.serializeKAndRadius(buffer); } /** @@ -50,6 +51,6 @@ public void serializeBy(ByteBuffer buffer) { */ @Override public List toLog() { - return Collections.singletonList("k=" + k); + return new ArrayList<>(base.toLog()); } } diff --git a/src/main/java/ru/rt/restream/reindexer/vector/params/IndexHnswSearchParam.java b/src/main/java/ru/rt/restream/reindexer/vector/params/IndexHnswSearchParam.java index 0bb57de..5144c71 100644 --- a/src/main/java/ru/rt/restream/reindexer/vector/params/IndexHnswSearchParam.java +++ b/src/main/java/ru/rt/restream/reindexer/vector/params/IndexHnswSearchParam.java @@ -18,10 +18,10 @@ import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.Getter; +import lombok.NonNull; import ru.rt.restream.reindexer.binding.cproto.ByteBuffer; -import java.util.Arrays; -import java.util.Collections; +import java.util.ArrayList; import java.util.List; import static ru.rt.restream.reindexer.binding.Consts.KNN_QUERY_PARAMS_VERSION; @@ -31,9 +31,10 @@ @AllArgsConstructor(access = AccessLevel.PACKAGE) public class IndexHnswSearchParam implements KnnSearchParam { /** - * The maximum number of documents returned from the index for subsequent filtering. + * Common parameters for KNN search. */ - private final int k; + @NonNull + private final BaseKnnSearchParam base; /** * The size of the dynamic list for the nearest neighbors. @@ -50,9 +51,9 @@ public class IndexHnswSearchParam implements KnnSearchParam { @Override public void serializeBy(ByteBuffer buffer) { buffer.putVarUInt32(KNN_QUERY_TYPE_HNSW) - .putVarUInt32(KNN_QUERY_PARAMS_VERSION) - .putVarUInt32(k) - .putVarInt32(ef); + .putVarUInt32(KNN_QUERY_PARAMS_VERSION); + base.serializeKAndRadius(buffer); + buffer.putVarInt32(ef); } /** @@ -60,6 +61,9 @@ public void serializeBy(ByteBuffer buffer) { */ @Override public List toLog() { - return Arrays.asList("k=" + k, "ef=" + ef); + List values = new ArrayList<>(3); + values.addAll(base.toLog()); + values.add("ef=" + ef); + return values; } } diff --git a/src/main/java/ru/rt/restream/reindexer/vector/params/IndexIvfSearchParam.java b/src/main/java/ru/rt/restream/reindexer/vector/params/IndexIvfSearchParam.java index e1429a4..bcfc89f 100644 --- a/src/main/java/ru/rt/restream/reindexer/vector/params/IndexIvfSearchParam.java +++ b/src/main/java/ru/rt/restream/reindexer/vector/params/IndexIvfSearchParam.java @@ -18,9 +18,10 @@ import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.Getter; +import lombok.NonNull; import ru.rt.restream.reindexer.binding.cproto.ByteBuffer; -import java.util.Arrays; +import java.util.ArrayList; import java.util.List; import static ru.rt.restream.reindexer.binding.Consts.KNN_QUERY_PARAMS_VERSION; @@ -30,9 +31,10 @@ @AllArgsConstructor(access = AccessLevel.PACKAGE) public class IndexIvfSearchParam implements KnnSearchParam { /** - * The maximum number of documents returned from the index for subsequent filtering. + * Common parameters for KNN search. */ - private final int k; + @NonNull + private final BaseKnnSearchParam base; /** * the number of clusters to be looked at during the search. @@ -49,9 +51,9 @@ public class IndexIvfSearchParam implements KnnSearchParam { @Override public void serializeBy(ByteBuffer buffer) { buffer.putVarUInt32(KNN_QUERY_TYPE_IVF) - .putVarUInt32(KNN_QUERY_PARAMS_VERSION) - .putVarUInt32(k) - .putVarUInt32(nProbe); + .putVarUInt32(KNN_QUERY_PARAMS_VERSION); + base.serializeKAndRadius(buffer); + buffer.putVarUInt32(nProbe); } /** @@ -59,6 +61,9 @@ public void serializeBy(ByteBuffer buffer) { */ @Override public List toLog() { - return Arrays.asList("k=" + k, "nprobe=" + nProbe); + List values = new ArrayList<>(3); + values.addAll(base.toLog()); + values.add("nprobe=" + nProbe); + return values; } } diff --git a/src/main/java/ru/rt/restream/reindexer/vector/params/KnnParams.java b/src/main/java/ru/rt/restream/reindexer/vector/params/KnnParams.java index 1be4d26..d399962 100644 --- a/src/main/java/ru/rt/restream/reindexer/vector/params/KnnParams.java +++ b/src/main/java/ru/rt/restream/reindexer/vector/params/KnnParams.java @@ -15,34 +15,66 @@ */ package ru.rt.restream.reindexer.vector.params; +import lombok.NonNull; + /** * Factories for KnnSearchParams. */ public class KnnParams { + @Deprecated public static BaseKnnSearchParam base(int k) { checkK(k); - return new BaseKnnSearchParam(k); + return new BaseKnnSearchParam(k, null); } - public static IndexHnswSearchParam hnsw(int k, int ef) { + public static BaseKnnSearchParam base(int k, float radius) { + checkK(k); + return new BaseKnnSearchParam(k, radius); + } + + public static BaseKnnSearchParam k(int k) { checkK(k); + return new BaseKnnSearchParam(k, null); + } + + public static BaseKnnSearchParam radius(float radius) { + return new BaseKnnSearchParam(null, radius); + } + + public static IndexHnswSearchParam hnsw(int k, int ef) { if (ef < k) { throw new IllegalArgumentException("Minimal value of 'ef' must be greater than or equal to 'k'"); } - return new IndexHnswSearchParam(k, ef); + return new IndexHnswSearchParam(k(k), ef); + } + + public static IndexHnswSearchParam hnsw(@NonNull BaseKnnSearchParam base, int ef) { + if (base.getK() != null && ef < base.getK()) { + throw new IllegalArgumentException("Minimal value of 'ef' must be greater than or equal to 'k'"); + } + return new IndexHnswSearchParam(base, ef); } public static IndexBfSearchParam bf(int k) { - checkK(k); - return new IndexBfSearchParam(k); + return new IndexBfSearchParam(k(k)); + } + + public static IndexBfSearchParam bf(@NonNull BaseKnnSearchParam base) { + return new IndexBfSearchParam(base); } public static IndexIvfSearchParam ivf(int k, int nProbe) { - checkK(k); if (nProbe <= 0) { throw new IllegalArgumentException("Minimal value of 'nProbe' must be greater than 0"); } - return new IndexIvfSearchParam(k, nProbe); + return new IndexIvfSearchParam(k(k), nProbe); + } + + public static IndexIvfSearchParam ivf(@NonNull BaseKnnSearchParam base, int nProbe) { + if (nProbe <= 0) { + throw new IllegalArgumentException("Minimal value of 'nProbe' must be greater than 0"); + } + return new IndexIvfSearchParam(base, nProbe); } private static void checkK(int k) { diff --git a/src/main/java/ru/rt/restream/reindexer/vector/params/KnnSearchParam.java b/src/main/java/ru/rt/restream/reindexer/vector/params/KnnSearchParam.java index 63fc0d8..9c2558c 100644 --- a/src/main/java/ru/rt/restream/reindexer/vector/params/KnnSearchParam.java +++ b/src/main/java/ru/rt/restream/reindexer/vector/params/KnnSearchParam.java @@ -23,13 +23,6 @@ * Common interface for KNN search parameters. */ public interface KnnSearchParam { - /** - * K - the maximum number of documents returned from the index for subsequent filtering. - * - *

Only required parameter for all vector index types. - */ - int getK(); - /** * Utility method for serializing KNN parameters to CJSON avoiding switch. */ diff --git a/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorBfTest.java b/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorBfTest.java index 26ee895..cd6631a 100644 --- a/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorBfTest.java +++ b/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorBfTest.java @@ -50,10 +50,15 @@ public abstract class FloatVectorBfTest extends DbBaseTest { private final String namespaceName = "items"; private Namespace vectorNs; + private List testItems; @BeforeEach public void setUp() { vectorNs = db.openNamespace(namespaceName, NamespaceOptions.defaultOptions(), VectorItem.class); + testItems = getTestVectorItems(); + for (VectorItem item : testItems) { + db.insert(namespaceName, item); + } } @Test @@ -91,16 +96,11 @@ public void testInsertWithWrongVectorSize_isException() { } @Test - public void testSearchWithBaseParams_isOk() { - List testItems = getTestVectorItems(); - for (VectorItem item : testItems) { - db.insert(namespaceName, item); - } - + public void testSearchWithBaseParamK_isOk() { List list = db.query(namespaceName, VectorItem.class) .selectAllFields() .whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f}, - KnnParams.base(2)) + KnnParams.k(2)) .toList(); assertThat(list.size(), is(2)); @@ -111,12 +111,51 @@ public void testSearchWithBaseParams_isOk() { } @Test - public void testSearchWithVecBfParams_isOk() { - List testItems = getTestVectorItems(); - for (VectorItem item : testItems) { - db.insert(namespaceName, item); - } + public void testSearchWithBaseParamRadius_isOk() { + List list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f}, + KnnParams.radius(0.7f)) + .toList(); + + assertThat(list.size(), is(4)); + assertThat(list.get(0).getId(), is(18)); + assertThat(list.get(1).getId(), is(6)); + assertThat(list.get(2).getId(), is(7)); + assertThat(list.get(3).getId(), is(8)); + } + @Test + public void testSearchWithBaseParamsKAndRadius_isOk() { + List list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f}, + KnnParams.base(3, 0.7f)) + .toList(); + + // by k (3 records) + by radius (4 records) = 3 records + assertThat(list.size(), is(3)); + assertThat(list.get(0).getId(), is(18)); + assertThat(list.get(1).getId(), is(6)); + assertThat(list.get(2).getId(), is(7)); + + list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f}, + KnnParams.base(5, 0.7f)) + .toList(); + + // by k (5 records) + by radius (4 records) = 4 records + assertThat(list.size(), is(4)); + assertThat(list.get(0).getId(), is(18)); + assertThat(list.get(1).getId(), is(6)); + assertThat(list.get(2).getId(), is(7)); + assertThat(list.get(3).getId(), is(8)); + } + + @Test + public void testSearchWithVecBfParams_isOk() { + // only k - 3 records List list = db.query(namespaceName, VectorItem.class) .selectAllFields() .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f}, @@ -130,15 +169,42 @@ public void testSearchWithVecBfParams_isOk() { assertThat(list.get(1).getVector(), is(testItems.get(18).getVector())); assertThat(list.get(2).getId(), is(19)); assertThat(list.get(2).getVector(), is(testItems.get(19).getVector())); + + // only radius 0.7 - 5 records + list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f}, + KnnParams.bf(KnnParams.radius(0.7f))) + .toList(); + + assertThat(list.size(), is(5)); + assertThat(list.get(0).getId(), is(8)); + assertThat(list.get(1).getId(), is(18)); + assertThat(list.get(2).getId(), is(19)); + assertThat(list.get(3).getId(), is(1)); + assertThat(list.get(4).getId(), is(2)); + + // by k (3 records) + by radius (5 records) = 3 records + list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f}, + KnnParams.bf(KnnParams.base(3, 0.7f))) + .toList(); + + assertThat(list.size(), is(3)); + + // by k (6 records) + by radius (5 records) = 5 records + list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f}, + KnnParams.bf(KnnParams.base(6, 0.7f))) + .toList(); + + assertThat(list.size(), is(5)); } @Test public void testSearchWithIncorrectVecBfParams_isException() { - List testItems = getTestVectorItems(); - for (VectorItem item : testItems) { - db.insert(namespaceName, item); - } - assertThrows(IllegalArgumentException.class, () -> db.query(namespaceName, VectorItem.class) .selectAllFields() @@ -149,12 +215,17 @@ public void testSearchWithIncorrectVecBfParams_isException() { } @Test - public void testSearchWithNotVecBfNorBaseParams_isException() { - List testItems = getTestVectorItems(); - for (VectorItem item : testItems) { - db.insert(namespaceName, item); - } + public void testTrySearchWithVecBfParamFromNullBaseParam_isException() { + assertThrows(NullPointerException.class, + () -> db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.5f, 0.0f, 0.6f}, + KnnParams.bf(null)) + .toList()); + } + @Test + public void testSearchWithNotVecBfNorBaseParams_isException() { assertThrows(RuntimeException.class, () -> db.query(namespaceName, VectorItem.class) .selectAllFields() diff --git a/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorHnswTest.java b/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorHnswTest.java index d4e76de..486371c 100644 --- a/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorHnswTest.java +++ b/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorHnswTest.java @@ -50,10 +50,15 @@ public abstract class FloatVectorHnswTest extends DbBaseTest { private final String namespaceName = "items"; private Namespace vectorNs; + private List testItems; @BeforeEach public void setUp() { vectorNs = db.openNamespace(namespaceName, NamespaceOptions.defaultOptions(), VectorItem.class); + testItems = getTestVectorItems(); + for (VectorItem item : testItems) { + db.insert(namespaceName, item); + } } @Test @@ -91,16 +96,11 @@ public void testInsertWithWrongVectorSize_isException() { } @Test - public void testSearchWithBaseParams_isOk() { - List testItems = getTestVectorItems(); - for (VectorItem item : testItems) { - db.insert(namespaceName, item); - } - + public void testSearchWithBaseParamK_isOk() { List list = db.query(namespaceName, VectorItem.class) .selectAllFields() .whereKnn("vector", new float[]{0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f}, - KnnParams.base(2)) + KnnParams.k(2)) .toList(); assertThat(list.size(), is(2)); @@ -111,16 +111,55 @@ public void testSearchWithBaseParams_isOk() { } @Test - public void testSearchWithHnswParams_isOk() { - List testItems = getTestVectorItems(); - for (VectorItem item : testItems) { - db.insert(namespaceName, item); - } + public void testSearchWithBaseParamRadius_isOk() { + List list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f}, + KnnParams.radius(0.4f)) + .toList(); + + assertThat(list.size(), is(4)); + assertThat(list.get(0).getId(), is(1)); + assertThat(list.get(1).getId(), is(2)); + assertThat(list.get(2).getId(), is(0)); + assertThat(list.get(3).getId(), is(3)); + } + @Test + public void testSearchWithBaseParamsKAndRadius_isOk() { + // by k (3 records) + by radius (4 records) = 3 records + List list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f}, + KnnParams.base(3, 0.4f)) + .toList(); + + assertThat(list.size(), is(3)); + assertThat(list.get(0).getId(), is(1)); + assertThat(list.get(1).getId(), is(2)); + assertThat(list.get(2).getId(), is(0)); + + // by k (5 records) + by radius (4 records) = 4 records + list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f, 0.13f}, + KnnParams.base(5, 0.4f)) + .toList(); + + assertThat(list.size(), is(4)); + assertThat(list.get(0).getId(), is(1)); + assertThat(list.get(1).getId(), is(2)); + assertThat(list.get(2).getId(), is(0)); + assertThat(list.get(3).getId(), is(3)); + } + + @Test + public void testSearchWithHnswParams_isOk() { + // k - 2 records List list = db.query(namespaceName, VectorItem.class) .selectAllFields() .whereKnn("vector", new float[]{0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f}, - KnnParams.hnsw(2, 2)) + KnnParams.hnsw(2, 5)) .toList(); assertThat(list.size(), is(2)); @@ -128,15 +167,47 @@ public void testSearchWithHnswParams_isOk() { assertThat(list.get(0).getVector(), is(testItems.get(2).getVector())); assertThat(list.get(1).getId(), is(3)); assertThat(list.get(1).getVector(), is(testItems.get(3).getVector())); + + // radius 0.4 - 4 records + list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f}, + KnnParams.hnsw(KnnParams.radius(0.4f), 5)) + .toList(); + + assertThat(list.size(), is(4)); + + // by k (3 records) + by radius (4 records) = 3 records + list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f}, + KnnParams.hnsw(KnnParams.base(3, 0.4f), 5)) + .toList(); + + assertThat(list.size(), is(3)); + + // by k (5 records) + by radius (4 records) = 4 records + list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f, 0.23f}, + KnnParams.hnsw(KnnParams.base(5, 0.4f), 5)) + .toList(); + + assertThat(list.size(), is(4)); } @Test - public void testSearchWithIncorrectHnswParams_isException() { - List testItems = getTestVectorItems(); - for (VectorItem item : testItems) { - db.insert(namespaceName, item); - } + public void testTrySearchWithHnswParamFromNullBaseParam_isException() { + assertThrows(NullPointerException.class, + () -> db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.5f, 0.0f, 0.6f}, + KnnParams.hnsw(null, 5)) + .toList()); + } + @Test + public void testSearchWithIncorrectHnswParams_isException() { assertThrows(IllegalArgumentException.class, () -> db.query(namespaceName, VectorItem.class) .selectAllFields() @@ -156,11 +227,6 @@ public void testSearchWithIncorrectHnswParams_isException() { @Test public void testSearchWithNotHnswNorBaseParams_isException() { - List testItems = getTestVectorItems(); - for (VectorItem item : testItems) { - db.insert(namespaceName, item); - } - assertThrows(RuntimeException.class, () -> db.query(namespaceName, VectorItem.class) .selectAllFields() diff --git a/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorIvfTest.java b/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorIvfTest.java index b6e6cfa..9d36acb 100644 --- a/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorIvfTest.java +++ b/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorIvfTest.java @@ -41,6 +41,7 @@ import static org.hamcrest.Matchers.notNullValue; import static org.junit.jupiter.api.Assertions.assertThrows; import static ru.rt.restream.reindexer.Query.Condition.EQ; +import static ru.rt.restream.reindexer.vector.params.KnnParams.radius; import static ru.rt.restream.util.ReindexerUtils.getIndexByName; /** @@ -50,10 +51,15 @@ public abstract class FloatVectorIvfTest extends DbBaseTest { private final String namespaceName = "items"; private Namespace vectorNs; + private List testItems; @BeforeEach public void setUp() { vectorNs = db.openNamespace(namespaceName, NamespaceOptions.defaultOptions(), VectorItem.class); + testItems = getTestVectorItems(); + for (VectorItem item : testItems) { + db.insert(namespaceName, item); + } } @Test @@ -91,16 +97,11 @@ public void testInsertWithWrongVectorSize_isException() { } @Test - public void testSearchWithBaseParams_isOk() { - List testItems = getTestVectorItems(); - for (VectorItem item : testItems) { - db.insert(namespaceName, item); - } - + public void testSearchWithBaseParamK_isOk() { List list = db.query(namespaceName, VectorItem.class) .selectAllFields() .whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f}, - KnnParams.base(2)) + KnnParams.k(2)) .toList(); assertThat(list.size(), is(2)); @@ -111,34 +112,106 @@ public void testSearchWithBaseParams_isOk() { } @Test - public void testSearchWithIvfParams_isOk() { - List testItems = getTestVectorItems(); - for (VectorItem item : testItems) { - db.insert(namespaceName, item); - } + public void testSearchWithBaseParamRadius_isOk() { + List list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f}, + radius(0.1f)) + .toList(); + assertThat(list.size(), is(4)); + assertThat(list.get(0).getId(), is(18)); + assertThat(list.get(1).getId(), is(6)); + assertThat(list.get(2).getId(), is(7)); + assertThat(list.get(3).getId(), is(8)); + } + + @Test + public void testSearchWithBaseParamsKAndRadius_isOk() { + // by k (3 records) + by radius (4 records) = 3 records List list = db.query(namespaceName, VectorItem.class) .selectAllFields() - .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f}, - KnnParams.ivf(3, 3)) + .whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f}, + KnnParams.base(3, 0.1f)) .toList(); assertThat(list.size(), is(3)); + assertThat(list.get(0).getId(), is(18)); + assertThat(list.get(1).getId(), is(7)); + assertThat(list.get(2).getId(), is(8)); + + // by k (5 records) + by radius (4 records) = 4 records + list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f}, + KnnParams.base(5, 0.1f)) + .toList(); + + assertThat(list.size(), is(4)); + assertThat(list.get(0).getId(), is(18)); + assertThat(list.get(1).getId(), is(6)); + assertThat(list.get(2).getId(), is(7)); + assertThat(list.get(3).getId(), is(8)); + } + + @Test + public void testSearchWithIvfParams_isOk() { + // only k - 2 records + List list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f}, + KnnParams.ivf(2, 3)) + .toList(); + + assertThat(list.size(), is(2)); assertThat(list.get(0).getId(), is(8)); assertThat(list.get(0).getVector(), is(testItems.get(8).getVector())); assertThat(list.get(1).getId(), is(18)); assertThat(list.get(1).getVector(), is(testItems.get(18).getVector())); + + // only radius - 3 records + list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f}, + KnnParams.ivf(KnnParams.radius(0.3f), 3)) + .toList(); + + assertThat(list.size(), is(3)); + assertThat(list.get(0).getId(), is(8)); + assertThat(list.get(1).getId(), is(18)); assertThat(list.get(2).getId(), is(19)); - assertThat(list.get(2).getVector(), is(testItems.get(19).getVector())); + + // by k (2 records) + by radius (3 records) = 2 records + list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f}, + KnnParams.ivf(KnnParams.base(2, 0.3f), 3)) + .toList(); + + assertThat(list.size(), is(2)); + + // by k (4 records) + by radius (3 records) = 3 records + list = db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f}, + KnnParams.ivf(KnnParams.base(4, 0.3f), 3)) + .toList(); + + assertThat(list.size(), is(3)); } @Test - public void testSearchWithIncorrectIvfParams_isException() { - List testItems = getTestVectorItems(); - for (VectorItem item : testItems) { - db.insert(namespaceName, item); - } + public void testTrySearchWithIvfParamFromNullBaseParam_isException() { + assertThrows(NullPointerException.class, + () -> db.query(namespaceName, VectorItem.class) + .selectAllFields() + .whereKnn("vector", new float[]{0.5f, 0.0f, 0.6f}, + KnnParams.bf(null)) + .toList()); + } + @Test + public void testSearchWithIncorrectIvfParams_isException() { assertThrows(IllegalArgumentException.class, () -> db.query(namespaceName, VectorItem.class) .selectAllFields() @@ -158,11 +231,6 @@ public void testSearchWithIncorrectIvfParams_isException() { @Test public void testSearchWithNotIvfNorBaseParams_isException() { - List testItems = getTestVectorItems(); - for (VectorItem item : testItems) { - db.insert(namespaceName, item); - } - assertThrows(RuntimeException.class, () -> db.query(namespaceName, VectorItem.class) .selectAllFields()