diff --git a/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java b/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java index 43703cfe75f3..eecf6bde2761 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java +++ b/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java @@ -73,6 +73,16 @@ public SSTableId getSourceSstableId() return sourceSstableId; } + @Override + public PrimaryKeyWithSource forStaticRow() + { + return new PrimaryKeyWithSource(primaryKey().forStaticRow(), + sourceSstableId, + sourceRowId, + sourceSstableMinKey, + sourceSstableMaxKey); + } + @Override public Token token() { diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyFactory.java b/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyFactory.java index c1b47916229b..80eac7ce0cb8 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyFactory.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyFactory.java @@ -80,6 +80,12 @@ public PrimaryKey loadDeferred() return this; } + @Override + public PartitionAwarePrimaryKey forStaticRow() + { + return this; + } + @Override public Token token() { diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java b/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java index cba1e1124ce3..1664c51634ac 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java @@ -82,6 +82,12 @@ private RowAwarePrimaryKey(Token token, DecoratedKey partitionKey, Clustering cl this.primaryKeySupplier = primaryKeySupplier; } + @Override + public RowAwarePrimaryKey forStaticRow() + { + return new RowAwarePrimaryKey(token, partitionKey, Clustering.STATIC_CLUSTERING, primaryKeySupplier); + } + @Override public Token 
token() { @@ -220,12 +226,12 @@ public long ramBytesUsed() // Object header + 4 references (token, partitionKey, clustering, primaryKeySupplier) + implicit outer reference long size = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + 5L * RamUsageEstimator.NUM_BYTES_OBJECT_REF; - + if (token != null) size += token.getHeapSize(); if (partitionKey != null) - size += RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + - 2L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + // token and key references + size += RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + + 2L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + // token and key references 2L * Long.BYTES; // We don't count clustering size here as it's managed elsewhere return size; diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java index 3082248a7039..cf540b7caccd 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java @@ -314,9 +314,15 @@ public CloseableIterator orderResultsBy(QueryContext cont var relevantOrdinals = new IntHashSet(); var keysInRange = PrimaryKeyListUtil.getKeysInRange(keys, minimumKey, maximumKey); + boolean isStatic = indexContext.getDefinition().isStatic(); keysInRange.forEach(k -> { + // if the indexed column is static, we need to get the static row associated with the non-static row that + // might be referenced by the key + if (isStatic) + k = k.forStaticRow(); + var v = graph.vectorForKey(k); if (v == null) return; diff --git a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java index 8756acaac5cc..8680b076692b 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java +++ b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java @@ -339,6 +339,12 @@ public UnfilteredRowIterator 
getPartition(PrimaryKey key, ColumnFamilyStore.View // Class to transform the row to include its source table. Function>> rowTransformer = (Object sourceTable) -> new Transformation<>() { + @Override + protected Row applyToStatic(Row row) + { + return new RowWithSourceTable(row, sourceTable); + } + @Override protected Row applyToRow(Row row) { diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java index afc524ee6c6f..38c84db2f2df 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java @@ -698,18 +698,35 @@ public UnfilteredRowIterator readAndValidatePartition(PrimaryKey pk, List + * On a replica: + *
<ul> + * <li>filter(ScoreOrderedResultRetriever) is used to collect up to the top-K rows.</li> + * <li>We store any tombstones as well, to avoid losing them during coordinator reconciliation.</li> + * <li>The result is returned in PK order so that coordinator can merge from multiple replicas.</li> + * </ul>
* On a coordinator: - * - reorder(PartitionIterator) is used to consume all rows from the provided partitions, - * compute the order based on either a column ordering or a similarity score, and keep top-K. - * - The result is returned in score/sortkey order. + *
<ul> + * <li>reorder(PartitionIterator) is used to consume all rows from the provided partitions, + * compute the order based on either a column ordering or a similarity score, and keep top-K.</li> + * <li>The result is returned in score/sortkey order.</li> + * </ul>
*/ public class TopKProcessor { @@ -104,7 +109,7 @@ public TopKProcessor(ReadCommand command) this.indexContext = indexAndExpression.left; this.expression = indexAndExpression.right; - if (expression.operator() == Operator.ANN && !Ordering.Ann.useSyntheticScore()) + if (expression.operator() == Operator.ANN && !(Ordering.Ann.useSyntheticScore() && expression.column().isRegular())) this.queryVector = vts.createFloatVector(TypeUtil.decomposeVector(indexContext, expression.getIndexValue().duplicate())); else this.queryVector = null; @@ -253,6 +258,9 @@ private PartitionResults processScoredPartition(BaseRowIterator partitionRowI float keyAndStaticScore = getScoreForRow(key, staticRow); var pr = new PartitionResults(partitionInfo); + if (!partitionRowIterator.hasNext()) + pr.addRow(Triple.of(partitionInfo, BTreeRow.emptyRow(Clustering.EMPTY), keyAndStaticScore)); + while (partitionRowIterator.hasNext()) { Unfiltered unfiltered = partitionRowIterator.next(); @@ -278,16 +286,13 @@ private PartitionResults processScoredPartition(BaseRowIterator partitionRowI private int processSingleRowPartition(TreeMap> unfilteredByPartition, BaseRowIterator partitionRowIterator) { - if (!partitionRowIterator.hasNext()) - return 0; - - Unfiltered unfiltered = partitionRowIterator.next(); + Unfiltered unfiltered = partitionRowIterator.hasNext() ? partitionRowIterator.next() : null; assert !partitionRowIterator.hasNext() : "Only one row should be returned"; // Always include tombstones for coordinator. It relies on ReadCommand#withMetricsRecording to throw // TombstoneOverwhelmingException to prevent OOM. PartitionInfo partitionInfo = PartitionInfo.create(partitionRowIterator); addUnfiltered(unfilteredByPartition, partitionInfo, unfiltered); - return unfiltered.isRangeTombstoneMarker() ? 0 : 1; + return unfiltered != null && unfiltered.isRangeTombstoneMarker() ? 
0 : 1; } private void addUnfiltered(SortedMap> unfilteredByPartition, @@ -295,14 +300,15 @@ private void addUnfiltered(SortedMap> unfilte Unfiltered unfiltered) { var map = unfilteredByPartition.computeIfAbsent(partitionInfo, k -> new TreeSet<>(command.metadata().comparator)); - map.add(unfiltered); + if (unfiltered != null) + map.add(unfiltered); } private float getScoreForRow(DecoratedKey key, Row row) { ColumnMetadata column = indexContext.getDefinition(); - if (column.isPrimaryKeyColumn() && key == null) + if (column.isPartitionKey() && key == null) return 0; if (column.isStatic() && !row.isStatic()) diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKey.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKey.java index 0bcd16c0a7f3..2bf5a4f0af00 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKey.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKey.java @@ -111,6 +111,13 @@ static Factory factory(ClusteringComparator clusteringComparator, IndexFeatureSe : new PartitionAwarePrimaryKeyFactory(); } + /** + * Returns a {@link PrimaryKey} to fetch the static row of the partition associated with this primary key. + * + * @return a {@link PrimaryKey} for the static row + */ + PrimaryKey forStaticRow(); + /** * Returns the {@link Token} associated with this primary key. 
* diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java index bb931c942fbf..480cba52937b 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java @@ -39,16 +39,26 @@ public class PrimaryKeyWithByteComparable extends PrimaryKeyWithSortKey public PrimaryKeyWithByteComparable(IndexContext context, Memtable sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) { - super(context, sourceTable, primaryKey); - this.byteComparable = byteComparable; + this(context, (Object) sourceTable, primaryKey, byteComparable); + } + + public PrimaryKeyWithByteComparable(IndexContext context, SSTableId sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) + { + this(context, (Object) sourceTable, primaryKey, byteComparable); } - public PrimaryKeyWithByteComparable(IndexContext context, SSTableId sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) + private PrimaryKeyWithByteComparable(IndexContext context, Object sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) { super(context, sourceTable, primaryKey); this.byteComparable = byteComparable; } + @Override + public PrimaryKeyWithByteComparable forStaticRow() + { + return new PrimaryKeyWithByteComparable(context, sourceTable, primaryKey.forStaticRow(), byteComparable); + } + @Override protected boolean isIndexDataEqualToLiveData(ByteBuffer value) { diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java index b88c210d65f2..572851420b93 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java @@ -34,16 +34,26 @@ public class PrimaryKeyWithScore 
extends PrimaryKeyWithSortKey public PrimaryKeyWithScore(IndexContext context, Memtable source, PrimaryKey primaryKey, float indexScore) { - super(context, source, primaryKey); - this.indexScore = indexScore; + this(context, (Object) source, primaryKey, indexScore); + } + + public PrimaryKeyWithScore(IndexContext context, SSTableId source, PrimaryKey primaryKey, float indexScore) + { + this(context, (Object) source, primaryKey, indexScore); } - public PrimaryKeyWithScore(IndexContext context, SSTableId source, PrimaryKey primaryKey, float indexScore) + private PrimaryKeyWithScore(IndexContext context, Object source, PrimaryKey primaryKey, float indexScore) { super(context, source, primaryKey); this.indexScore = indexScore; } + @Override + public PrimaryKeyWithScore forStaticRow() + { + return new PrimaryKeyWithScore(context, sourceTable, primaryKey.forStaticRow(), indexScore); + } + @Override protected boolean isIndexDataEqualToLiveData(ByteBuffer value) { diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java index b3a6fb4338e5..97ba2f379269 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java @@ -28,6 +28,7 @@ import org.apache.cassandra.dht.Token; import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.io.sstable.SSTableId; +import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.utils.bytecomparable.ByteComparable; import org.apache.cassandra.utils.bytecomparable.ByteSource; @@ -40,19 +41,13 @@ public abstract class PrimaryKeyWithSortKey implements PrimaryKey { protected final IndexContext context; - private final PrimaryKey primaryKey; + protected final PrimaryKey primaryKey; // Either a Memtable reference or an SSTableId reference - private final Object sourceTable; + protected final Object 
sourceTable; - protected PrimaryKeyWithSortKey(IndexContext context, Memtable sourceTable, PrimaryKey primaryKey) - { - this.context = context; - this.sourceTable = sourceTable; - this.primaryKey = primaryKey; - } - - protected PrimaryKeyWithSortKey(IndexContext context, SSTableId sourceTable, PrimaryKey primaryKey) + protected PrimaryKeyWithSortKey(IndexContext context, Object sourceTable, PrimaryKey primaryKey) { + assert sourceTable instanceof Memtable || sourceTable instanceof SSTableId; this.context = context; this.sourceTable = sourceTable; this.primaryKey = primaryKey; @@ -65,10 +60,18 @@ public PrimaryKey primaryKey() public boolean isIndexDataValid(Row row, int nowInSecs) { - assert context.getDefinition().isRegular() : "Only regular columns are supported, got " + context.getDefinition(); - var cell = row.getCell(context.getDefinition()); + ColumnMetadata column = context.getDefinition(); + + if (row.isStatic() && !column.isStatic()) + return true; + + if (!row.isStatic() && !column.isRegular()) + return true; + + var cell = row.getCell(column); if (!cell.isLive(nowInSecs)) return false; + assert cell instanceof CellWithSourceTable : "Expected CellWithSource, got " + cell.getClass(); return sourceTable.equals(((CellWithSourceTable) cell).sourceTable()) && isIndexDataEqualToLiveData(cell.buffer()); @@ -103,7 +106,6 @@ public final boolean equals(Object obj) return primaryKey.equals(((PrimaryKeyWithSortKey) obj).primaryKey()); } - // Generic primary key wrapper methods: @Override public Token token() diff --git a/test/unit/org/apache/cassandra/index/sai/cql/PlanWithIndexHintsTest.java b/test/unit/org/apache/cassandra/index/sai/cql/PlanWithIndexHintsTest.java index 0f8b8a1a23e1..73a40c3da11a 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/PlanWithIndexHintsTest.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/PlanWithIndexHintsTest.java @@ -20,7 +20,6 @@ import org.apache.cassandra.net.MessagingService; import org.junit.Assume; -import 
org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; @@ -60,12 +59,6 @@ public static void setUpClass() SAITester.setUpClass(); } - @Before - public void setup() throws Throwable - { - CassandraRelevantProperties.DS_CURRENT_MESSAGING_VERSION.setInt(MessagingService.VERSION_DS_12); - } - @Test public void testQueryPlanning() throws Throwable { @@ -630,16 +623,16 @@ public void columnPositionsTest() // insert some data String insert = "INSERT INTO %s (k1, k2, k3, c1, c2, c3, s1, s2, s3, r1, r2, r3) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; - Object[] row1 = row(1, "a", vector(0.1f, 0.1f), 10, "aa", vector(0.01f, 0.01f), 100, "aaa", vector(0.001f, 0.001f), 1000, "aaaa", vector(0.0001f, 0.0001f)); - Object[] row2 = row(1, "a", vector(0.1f, 0.2f), 11, "ab", vector(0.01f, 0.02f), 100, "aaa", vector(0.001f, 0.002f), 1001, "aaab", vector(0.0001f, 0.0002f)); - Object[] row3 = row(2, "b", vector(0.2f, 0.1f), 10, "aa", vector(0.02f, 0.01f), 200, "bbb", vector(0.002f, 0.001f), 1000, "aaaa", vector(0.0002f, 0.0001f)); - Object[] row4 = row(2, "b", vector(0.2f, 0.2f), 11, "ab", vector(0.02f, 0.02f), 200, "bbb", vector(0.002f, 0.002f), 1001, "aaab", vector(0.0002f, 0.0002f)); - execute(insert, row1); - execute(insert, row2); - execute(insert, row3); - execute(insert, row4); + execute(insert, row(1, "a", vector(0.1f, 0.1f), 10, "aa", vector(0.01f, 0.01f), 100, "aaa", vector(0.001f, 0.001f), 1000, "aaaa", vector(0.0001f, 0.0001f))); + execute(insert, row(1, "a", vector(0.1f, 0.2f), 11, "ab", vector(0.01f, 0.02f), 100, "aaa", vector(0.001f, 0.002f), 1001, "aaab", vector(0.0001f, 0.0002f))); + execute(insert, row(2, "b", vector(0.2f, 0.1f), 10, "aa", vector(0.02f, 0.01f), 200, "bbb", vector(0.002f, 0.001f), 1000, "aaaa", vector(0.0002f, 0.0001f))); + execute(insert, row(2, "b", vector(0.2f, 0.2f), 11, "ab", vector(0.02f, 0.02f), 200, "bbb", vector(0.002f, 0.002f), 1001, "aaab", vector(0.0002f, 0.0002f))); + Object[] row1 = row(1, "a", 10, "aa"); + Object[] row2 
= row(1, "a", 11, "ab"); + Object[] row3 = row(2, "b", 10, "aa"); + Object[] row4 = row(2, "b", 11, "ab"); - String query = "SELECT k1, k2, k3, c1, c2, c3, s1, s2, s3, r1, r2, r3 FROM %s "; + final String query = "SELECT k1, k2, c1, c2 FROM %s "; // query partition key columns without hints assertThatPlanFor(query + "WHERE k1=1", row1, row2).uses("partition_numeric"); @@ -652,9 +645,8 @@ public void columnPositionsTest() assertThatPlanFor(query + "WHERE k2:'a'", row1, row2).uses("partition_analyzed"); if (supportsBM25) assertBM25OnNonRegularColumnIsRejected(query + "ORDER BY k2 BM25 OF 'a' LIMIT 10", ColumnMetadata.Kind.PARTITION_KEY, "k2"); - // TODO: this hits CNDB-14343, we should either enable this or remove the index creation when that is resolved - // if (supportsANN) - // assertThatPlanFor(query + "ORDER BY k3 ANN OF [0.1, 0.2] LIMIT 10", row1, row2).uses("partition_ann"); + if (supportsANN) + assertThatPlanFor(query + "ORDER BY k3 ANN OF [0.1, 0.1] LIMIT 10", row1, row2, row3, row4).uses("partition_ann"); // query partition key columns with included index assertThatPlanFor(query + "WHERE k1=1 WITH included_indexes={partition_numeric}", row1, row2).uses("partition_numeric"); @@ -668,9 +660,8 @@ public void columnPositionsTest() if (supportsBM25) assertBM25OnNonRegularColumnIsRejected(query + "ORDER BY k2 BM25 OF 'a' LIMIT 10 WITH included_indexes={partition_analyzed}", ColumnMetadata.Kind.PARTITION_KEY, "k2"); - // TODO: this hits CNDB-14343, we should either enable this or remove the index creation when that is resolved - // if (supportsANN) - // assertThatPlanFor(query + "ORDER BY k3 ANN OF [0.1, 0.2] LIMIT 10 WITH included_indexes={partition_ann}", row1, row2).uses("partition_ann"); + if (supportsANN) + assertThatPlanFor(query + "ORDER BY k3 ANN OF [0.1, 0.1] LIMIT 10 WITH included_indexes={partition_ann}", row1, row2, row3, row4).uses("partition_ann"); // query partition key columns with excluded index assertThatPlanFor(query + "WHERE k1=1 ALLOW 
FILTERING WITH excluded_indexes={partition_numeric}", row1, row2).usesNone(); @@ -684,7 +675,7 @@ public void columnPositionsTest() if (supportsBM25) assertBM25RequiresAnAnalyzedIndex(query + "ORDER BY k2 BM25 OF 'a' LIMIT 10 WITH excluded_indexes={partition_analyzed}", "k2"); if (supportsANN) - assertANNOrderingNeedsIndex(query + "ORDER BY k3 ANN OF [0.1, 0.2] LIMIT 10 WITH excluded_indexes={partition_ann}", "k3"); + assertANNOrderingNeedsIndex(query + "ORDER BY k3 ANN OF [0.1, 0.1] LIMIT 10 WITH excluded_indexes={partition_ann}", "k3"); // query clustering key columns without hints assertThatPlanFor(query + "WHERE c1=10", row1, row3).uses("clustering_numeric"); @@ -697,9 +688,8 @@ public void columnPositionsTest() assertCannotBeRestrictedByClustering(query + "WHERE c2:'aa'", "c2"); if (supportsBM25) assertBM25OnNonRegularColumnIsRejected(query + "ORDER BY c2 BM25 OF 'aa' LIMIT 10", ColumnMetadata.Kind.CLUSTERING, "c2"); - // TODO: this hits CNDB-14343, we should either enable this or remove the index creation when that is resolved - // if (supportsANN) - // assertThatPlanFor(query + "ORDER BY c3 ANN OF [0.1, 0.2] LIMIT 10", row1, row2).uses("clustering_ann"); + if (supportsANN) + assertThatPlanFor(query + "ORDER BY c3 ANN OF [0.1, 0.1] LIMIT 10", row1, row2, row3, row4).uses("clustering_ann"); // query clustering key columns with included index assertThatPlanFor(query + "WHERE c1=10 WITH included_indexes={clustering_numeric}", row1, row3).uses("clustering_numeric"); @@ -713,9 +703,8 @@ public void columnPositionsTest() if (supportsBM25) assertBM25OnNonRegularColumnIsRejected(query + "ORDER BY c2 BM25 OF 'aa' LIMIT 10 WITH included_indexes={clustering_analyzed}", ColumnMetadata.Kind.CLUSTERING, "c2"); - // TODO: this hits CNDB-14343, we should either enable this or remove the index creation when that is resolved - // if (supportsANN) - // assertThatPlanFor(query + "ORDER BY c3 ANN OF [0.1, 0.2] LIMIT 10 WITH included_indexes={clustering_ann}", row1, 
row2).uses("clustering_ann"); + if (supportsANN) + assertThatPlanFor(query + "ORDER BY c3 ANN OF [0.1, 0.1] LIMIT 10 WITH included_indexes={clustering_ann}", row1, row2, row3, row4).uses("clustering_ann"); // query clustering key columns with excluded index assertThatPlanFor(query + "WHERE c1=10 ALLOW FILTERING WITH excluded_indexes={clustering_numeric}", row1, row3).usesNone(); @@ -742,9 +731,8 @@ public void columnPositionsTest() assertThatPlanFor(query + "WHERE s2:'aaa'", row1, row2).uses("static_analyzed"); if (supportsBM25) assertBM25OnNonRegularColumnIsRejected(query + "ORDER BY s2 BM25 OF 'aa' LIMIT 10", ColumnMetadata.Kind.STATIC, "s2"); - // TODO: this hits CNDB-14343, we should either enable this or remove the index creation when that is resolved - // if (supportsANN) - // assertThatPlanFor(query + "ORDER BY s3 ANN OF [0.1, 0.2] LIMIT 10", row1, row2).uses("static_ann"); + if (supportsANN) + assertThatPlanFor(query + "ORDER BY s3 ANN OF [0.1, 0.1] LIMIT 10", row1, row2, row3, row4).uses("static_ann"); // query static columns with included index assertThatPlanFor(query + "WHERE s1=100 WITH included_indexes={static_numeric}", row1, row2).uses("static_numeric"); @@ -758,9 +746,8 @@ public void columnPositionsTest() if (supportsBM25) assertBM25OnNonRegularColumnIsRejected(query + "ORDER BY s2 BM25 OF 'aa' LIMIT 10 WITH included_indexes={static_analyzed}", ColumnMetadata.Kind.STATIC, "s2"); - // TODO: this hits CNDB-14343, we should either enable this or remove the index creation when that is resolved - // if (supportsANN) - // assertThatPlanFor(query + "ORDER BY s3 ANN OF [0.1, 0.2] LIMIT 10 WITH included_indexes={static_ann}", row1, row2).uses("static_ann"); + if (supportsANN) + assertThatPlanFor(query + "ORDER BY s3 ANN OF [0.1, 0.1] LIMIT 10 WITH included_indexes={static_ann}", row1, row2, row3, row4).uses("static_ann"); // query static columns with excluded index assertThatPlanFor(query + "WHERE s1=100 ALLOW FILTERING WITH 
excluded_indexes={static_numeric}", row1, row2).usesNone(); @@ -774,7 +761,7 @@ public void columnPositionsTest() if (supportsBM25) assertBM25RequiresAnAnalyzedIndex(query + "ORDER BY s2 BM25 OF 'aa' LIMIT 10 WITH excluded_indexes={static_analyzed}", "s2"); if (supportsANN) - assertANNOrderingNeedsIndex(query + "ORDER BY s3 ANN OF [0.1, 0.2] LIMIT 10 WITH excluded_indexes={static_ann}", "s3"); + assertANNOrderingNeedsIndex(query + "ORDER BY s3 ANN OF [0.1, 0.1] LIMIT 10 WITH excluded_indexes={static_ann}", "s3"); // query regular columns without hints assertThatPlanFor(query + "WHERE r1=1000", row1, row3).uses("regular_numeric"); @@ -788,7 +775,7 @@ public void columnPositionsTest() if (supportsBM25) assertThatPlanFor(query + "ORDER BY r2 BM25 OF 'aaaa' LIMIT 10", row1, row3).uses("regular_analyzed"); if (supportsANN) - assertThatPlanFor(query + "ORDER BY r3 ANN OF [0.1, 0.2] LIMIT 10", row1, row2, row3, row4).uses("regular_ann"); + assertThatPlanFor(query + "ORDER BY r3 ANN OF [0.1, 0.1] LIMIT 10", row1, row2, row3, row4).uses("regular_ann"); // query regular columns with included indexes assertThatPlanFor(query + "WHERE r1=1000 WITH included_indexes={regular_numeric}", row1, row3).uses("regular_numeric"); @@ -802,7 +789,7 @@ public void columnPositionsTest() if (supportsBM25) assertThatPlanFor(query + "ORDER BY r2 BM25 OF 'aaaa' LIMIT 10 WITH included_indexes={regular_analyzed}", row1, row3).uses("regular_analyzed"); if (supportsANN) - assertThatPlanFor(query + "ORDER BY r3 ANN OF [0.1, 0.2] LIMIT 10 WITH included_indexes={regular_ann}", row1, row2, row3, row4).uses("regular_ann"); + assertThatPlanFor(query + "ORDER BY r3 ANN OF [0.1, 0.1] LIMIT 10 WITH included_indexes={regular_ann}", row1, row2, row3, row4).uses("regular_ann"); // query regular columns with excluded indexes assertThatPlanFor(query + "WHERE r1=1000 ALLOW FILTERING WITH excluded_indexes={regular_numeric}", row1, row3).usesNone(); @@ -816,7 +803,7 @@ public void columnPositionsTest() if 
(supportsBM25) assertBM25RequiresAnAnalyzedIndex(query + "ORDER BY r2 BM25 OF 'aaaa' LIMIT 10 WITH excluded_indexes={regular_analyzed}", "r2"); if (supportsANN) - assertANNOrderingNeedsIndex(query + "ORDER BY r3 ANN OF [0.1, 0.2] LIMIT 10 WITH excluded_indexes={regular_ann}", "r3"); + assertANNOrderingNeedsIndex(query + "ORDER BY r3 ANN OF [0.1, 0.1] LIMIT 10 WITH excluded_indexes={regular_ann}", "r3"); } private void assertNeedsAllowFiltering(String query) diff --git a/test/unit/org/apache/cassandra/index/sai/cql/VectorColumnPositionsTest.java b/test/unit/org/apache/cassandra/index/sai/cql/VectorColumnPositionsTest.java new file mode 100644 index 000000000000..fbac8b9de3c5 --- /dev/null +++ b/test/unit/org/apache/cassandra/index/sai/cql/VectorColumnPositionsTest.java @@ -0,0 +1,654 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.cql; + +import org.junit.Before; +import org.junit.Test; + +import org.apache.cassandra.index.sai.SAIUtil; + +/** + * Tests vector indexes in different column positions (partition key, clustering key, static and regular columns). 
+ */ +public class VectorColumnPositionsTest extends VectorTester.Versioned +{ + @Before + public void setupVersion() + { + SAIUtil.setCurrentVersion(version); + } + + @Test + public void testPartitionKey() + { + createTable("CREATE TABLE %s (k vector PRIMARY KEY)"); + assertInvalidMessage("Cannot create secondary index on the only partition key column k", + "CREATE CUSTOM INDEX ON %s(k) USING 'StorageAttachedIndex'"); + } + + @Test + public void testPartitionKeyComponentWithColumns() throws Throwable + { + createTable("CREATE TABLE %s (k1 int, k2 vector, v int, PRIMARY KEY((k1, k2)))"); + createIndex("CREATE CUSTOM INDEX ON %s(k2) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + String insert = "INSERT INTO %s (k1, k2, v) VALUES (?, ?, ?)"; + execute(insert, row(1, vector(0.1f, 0.1f), 0)); + execute(insert, row(2, vector(0.1f, 0.2f), 0)); + execute(insert, row(3, vector(0.1f, 0.3f), 1)); + execute(insert, row(4, vector(0.1f, 0.4f), 1)); + beforeAndAfterFlush(() -> { + // query with ANN only + assertRows(execute("SELECT k1 FROM %s ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT k1 FROM %s ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 2"), row(1), row(2)); + + // query with hybrid search + assertRows(execute("SELECT k1 FROM %s WHERE v>=0 ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT k1 FROM %s WHERE v=0 ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2)); + assertRows(execute("SELECT k1 FROM %s WHERE v>0 ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(4)); + }); + + // test again with a different order + execute("TRUNCATE TABLE %s"); + execute(insert, row(1, vector(0.1f, 0.4f), 0)); + execute(insert, row(2, vector(0.1f, 0.3f), 0)); + execute(insert, row(3, vector(0.1f, 0.2f), 1)); + execute(insert, row(4, vector(0.1f, 0.1f), 1)); + beforeAndAfterFlush(() -> { + // query with ANN only + 
assertRows(execute("SELECT k1 FROM %s ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3), row(2), row(1)); + assertRows(execute("SELECT k1 FROM %s ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 2"), row(4), row(3)); + + // query with hybrid search + assertRows(execute("SELECT k1 FROM %s WHERE v>=0 ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3), row(2), row(1)); + assertRows(execute("SELECT k1 FROM %s WHERE v=0 ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(1)); + assertRows(execute("SELECT k1 FROM %s WHERE v>0 ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3)); + }); + } + + @Test + public void testPartitionKeyComponentWithoutColumns() throws Throwable + { + createTable("CREATE TABLE %s (k1 int, k2 vector, v int, PRIMARY KEY((k1, k2)))"); + createIndex("CREATE CUSTOM INDEX ON %s(k2) USING 'StorageAttachedIndex'"); + + String insert = "INSERT INTO %s (k1, k2) VALUES (?, ?)"; + execute(insert, row(1, vector(0.1f, 0.1f))); + execute(insert, row(2, vector(0.1f, 0.2f))); + execute(insert, row(3, vector(0.1f, 0.3f))); + execute(insert, row(4, vector(0.1f, 0.4f))); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k1 FROM %s ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT k1 FROM %s ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 2"), row(1), row(2)); + }); + + // test again with a different order + execute("TRUNCATE TABLE %s"); + execute(insert, row(1, vector(0.1f, 0.4f))); + execute(insert, row(2, vector(0.1f, 0.3f))); + execute(insert, row(3, vector(0.1f, 0.2f))); + execute(insert, row(4, vector(0.1f, 0.1f))); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k1 FROM %s ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3), row(2), row(1)); + assertRows(execute("SELECT k1 FROM %s ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 2"), row(4), row(3)); + }); + } + + @Test + public void testPartitionKeyComponentWithClustering() throws Throwable + { + createTable("CREATE TABLE %s (k1 int, k2 vector, c 
int, v int, PRIMARY KEY((k1, k2), c))"); + createIndex("CREATE CUSTOM INDEX ON %s(k2) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + // insert static rows and non-static rows all at once + String insert = "INSERT INTO %s (k1, k2, c, v) VALUES (?, ?, ?, ?)"; + execute(insert, row(1, vector(0.1f, 0.1f), 10, 0)); + execute(insert, row(2, vector(0.1f, 0.2f), 20, 0)); + execute(insert, row(3, vector(0.1f, 0.3f), 30, 1)); + execute(insert, row(4, vector(0.1f, 0.4f), 40, 1)); + beforeAndAfterFlush(() -> { + // query with ANN only + assertRows(execute("SELECT k1 FROM %s ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT k1 FROM %s ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 2"), row(1), row(2)); + + // query with hybrid search + assertRows(execute("SELECT k1 FROM %s WHERE v>=0 ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT k1 FROM %s WHERE v=0 ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2)); + assertRows(execute("SELECT k1 FROM %s WHERE v>0 ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(4)); + }); + + // test again with a different order + execute("TRUNCATE TABLE %s"); + execute(insert, row(1, vector(0.1f, 0.4f), 10, 0)); + execute(insert, row(2, vector(0.1f, 0.3f), 20, 0)); + execute(insert, row(3, vector(0.1f, 0.2f), 30, 1)); + execute(insert, row(4, vector(0.1f, 0.1f), 40, 1)); + beforeAndAfterFlush(() -> { + // query with ANN only + assertRows(execute("SELECT k1 FROM %s ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3), row(2), row(1)); + assertRows(execute("SELECT k1 FROM %s ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 2"), row(4), row(3)); + + // query with hybrid search + assertRows(execute("SELECT k1 FROM %s WHERE v>=0 ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3), row(2), row(1)); + assertRows(execute("SELECT k1 FROM %s WHERE v=0 ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), 
row(2), row(1)); + assertRows(execute("SELECT k1 FROM %s WHERE v>0 ORDER BY k2 ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3)); + }); + } + + @Test + public void testClusteringKey() throws Throwable + { + createTable("CREATE TABLE %s (k int, c vector, v int, PRIMARY KEY(k, c))"); + createIndex("CREATE CUSTOM INDEX ON %s(c) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + String insert = "INSERT INTO %s (k, c, v) VALUES (?, ?, ?)"; + execute(insert, row(1, vector(0.1f, 0.1f), 0)); + execute(insert, row(2, vector(0.1f, 0.2f), 0)); + execute(insert, row(3, vector(0.1f, 0.3f), 1)); + execute(insert, row(4, vector(0.1f, 0.4f), 1)); + beforeAndAfterFlush(() -> { + // query with ANN only + assertRows(execute("SELECT k FROM %s ORDER BY c ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT k FROM %s ORDER BY c ANN OF [0.1, 0.1] LIMIT 2"), row(1), row(2)); + + // query with hybrid search + assertRows(execute("SELECT k FROM %s WHERE v>=0 ORDER BY c ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT k FROM %s WHERE v=0 ORDER BY c ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2)); + assertRows(execute("SELECT k FROM %s WHERE v>0 ORDER BY c ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(4)); + }); + + // test again with a different order + execute("TRUNCATE TABLE %s"); + execute(insert, row(1, vector(0.1f, 0.4f), 0)); + execute(insert, row(2, vector(0.1f, 0.3f), 0)); + execute(insert, row(3, vector(0.1f, 0.2f), 1)); + execute(insert, row(4, vector(0.1f, 0.1f), 1)); + beforeAndAfterFlush(() -> { + // query with ANN only + assertRows(execute("SELECT k FROM %s ORDER BY c ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3), row(2), row(1)); + assertRows(execute("SELECT k FROM %s ORDER BY c ANN OF [0.1, 0.1] LIMIT 2"), row(4), row(3)); + + // query with hybrid search + assertRows(execute("SELECT k FROM %s WHERE v>=0 ORDER BY c ANN OF [0.1, 0.1] LIMIT 10"), 
row(4), row(3), row(2), row(1)); + assertRows(execute("SELECT k FROM %s WHERE v=0 ORDER BY c ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(1)); + assertRows(execute("SELECT k FROM %s WHERE v>0 ORDER BY c ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3)); + }); + } + + @Test + public void testClusteringKeyComponent() throws Throwable + { + createTable("CREATE TABLE %s (k int, c1 int, c2 vector, v int, PRIMARY KEY(k, c1, c2))"); + createIndex("CREATE CUSTOM INDEX ON %s(c2) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + String insert = "INSERT INTO %s (k, c1, c2, v) VALUES (?, ?, ?, ?)"; + execute(insert, row(1, 1, vector(0.1f, 0.1f), 0)); + execute(insert, row(1, 2, vector(0.1f, 0.2f), 0)); + execute(insert, row(2, 3, vector(0.1f, 0.3f), 1)); + execute(insert, row(2, 4, vector(0.1f, 0.4f), 1)); + beforeAndAfterFlush(() -> { + // query with ANN only + assertRows(execute("SELECT c1 FROM %s ORDER BY c2 ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT c1 FROM %s ORDER BY c2 ANN OF [0.1, 0.1] LIMIT 2"), row(1), row(2)); + + // query with hybrid search + assertRows(execute("SELECT c1 FROM %s WHERE v>=0 ORDER BY c2 ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT c1 FROM %s WHERE v=0 ORDER BY c2 ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2)); + assertRows(execute("SELECT c1 FROM %s WHERE v>0 ORDER BY c2 ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(4)); + }); + + // test again with a different order + execute("TRUNCATE TABLE %s"); + execute(insert, row(1, 1, vector(0.1f, 0.4f), 0)); + execute(insert, row(1, 2, vector(0.1f, 0.3f), 0)); + execute(insert, row(2, 3, vector(0.1f, 0.2f), 1)); + execute(insert, row(2, 4, vector(0.1f, 0.1f), 1)); + beforeAndAfterFlush(() -> { + // query with ANN only + assertRows(execute("SELECT c1 FROM %s ORDER BY c2 ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3), row(2), row(1)); + assertRows(execute("SELECT c1 
FROM %s ORDER BY c2 ANN OF [0.1, 0.1] LIMIT 2"), row(4), row(3)); + + // query with hybrid search + assertRows(execute("SELECT c1 FROM %s WHERE v>=0 ORDER BY c2 ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3), row(2), row(1)); + assertRows(execute("SELECT c1 FROM %s WHERE v=0 ORDER BY c2 ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(1)); + assertRows(execute("SELECT c1 FROM %s WHERE v>0 ORDER BY c2 ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3)); + }); + } + + @Test + public void testRegularColumnWithoutClustering() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, r vector, v int)"); + createIndex("CREATE CUSTOM INDEX ON %s(r) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + String insert = "INSERT INTO %s (k, r, v) VALUES (?, ?, ?)"; + execute(insert, row(1, vector(0.1f, 0.1f), 0)); + execute(insert, row(2, vector(0.1f, 0.2f), 0)); + execute(insert, row(3, vector(0.1f, 0.3f), 1)); + execute(insert, row(4, vector(0.1f, 0.4f), 1)); + beforeAndAfterFlush(() -> { + // query with ANN only + assertRows(execute("SELECT k FROM %s ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT k FROM %s ORDER BY r ANN OF [0.1, 0.1] LIMIT 2"), row(1), row(2)); + + // query with hybrid search + assertRows(execute("SELECT k FROM %s WHERE v>=0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT k FROM %s WHERE v=0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2)); + assertRows(execute("SELECT k FROM %s WHERE v>0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(4)); + }); + + // test again with a different order + execute("TRUNCATE TABLE %s"); + execute(insert, row(1, vector(0.1f, 0.4f), 0)); + execute(insert, row(2, vector(0.1f, 0.3f), 0)); + execute(insert, row(3, vector(0.1f, 0.2f), 1)); + execute(insert, row(4, vector(0.1f, 0.1f), 1)); + beforeAndAfterFlush(() -> { + // query with ANN only + 
assertRows(execute("SELECT k FROM %s ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3), row(2), row(1)); + assertRows(execute("SELECT k FROM %s ORDER BY r ANN OF [0.1, 0.1] LIMIT 2"), row(4), row(3)); + + // query with hybrid search + assertRows(execute("SELECT k FROM %s WHERE v>=0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3), row(2), row(1)); + assertRows(execute("SELECT k FROM %s WHERE v=0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(1)); + assertRows(execute("SELECT k FROM %s WHERE v>0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3)); + }); + } + + @Test + public void testRegularColumnWithClustering() throws Throwable + { + createTable("CREATE TABLE %s (k int, c int, r vector, v int, PRIMARY KEY(k, c))"); + createIndex("CREATE CUSTOM INDEX ON %s(r) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + String insert = "INSERT INTO %s (k, c, r, v) VALUES (?, ?, ?, ?)"; + execute(insert, row(0, 1, vector(0.1f, 0.1f), 0)); + execute(insert, row(1, 2, vector(0.1f, 0.2f), 0)); + execute(insert, row(0, 3, vector(0.1f, 0.3f), 1)); + execute(insert, row(1, 4, vector(0.1f, 0.4f), 1)); + beforeAndAfterFlush(() -> { + // query with ANN only + assertRows(execute("SELECT c FROM %s ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT c FROM %s ORDER BY r ANN OF [0.1, 0.1] LIMIT 2"), row(1), row(2)); + + // query with hybrid search + assertRows(execute("SELECT c FROM %s WHERE v>=0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT c FROM %s WHERE v=0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2)); + assertRows(execute("SELECT c FROM %s WHERE v>0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(4)); + }); + + // update the best vector to make it the worst + execute(insert, row(0, 1, vector(0.1f, 0.5f), 0)); + beforeAndAfterFlush(() -> { + // query with ANN only + 
assertRows(execute("SELECT c FROM %s ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(3), row(4), row(1)); + assertRows(execute("SELECT c FROM %s ORDER BY r ANN OF [0.1, 0.1] LIMIT 2"), row(2), row(3)); + + // query with hybrid search + assertRows(execute("SELECT c FROM %s WHERE v>=0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(3), row(4), row(1)); + assertRows(execute("SELECT c FROM %s WHERE v=0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(1)); + assertRows(execute("SELECT c FROM %s WHERE v>0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(4)); + }); + + // test again with a different order + execute("TRUNCATE TABLE %s"); + execute(insert, row(0, 1, vector(0.1f, 0.4f), 0)); + execute(insert, row(1, 2, vector(0.1f, 0.3f), 0)); + execute(insert, row(0, 3, vector(0.1f, 0.2f), 1)); + execute(insert, row(1, 4, vector(0.1f, 0.1f), 1)); + beforeAndAfterFlush(() -> { + // query with ANN only + assertRows(execute("SELECT c FROM %s ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3), row(2), row(1)); + assertRows(execute("SELECT c FROM %s ORDER BY r ANN OF [0.1, 0.1] LIMIT 2"), row(4), row(3)); + + // query with hybrid search + assertRows(execute("SELECT c FROM %s WHERE v>=0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3), row(2), row(1)); + assertRows(execute("SELECT c FROM %s WHERE v=0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(1)); + assertRows(execute("SELECT c FROM %s WHERE v>0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3)); + }); + + // update the best vector to make it the worst + execute(insert, row(1, 4, vector(0.1f, 0.5f), 1)); + beforeAndAfterFlush(() -> { + // query with ANN only + assertRows(execute("SELECT c FROM %s ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(2), row(1), row(4)); + assertRows(execute("SELECT c FROM %s ORDER BY r ANN OF [0.1, 0.1] LIMIT 2"), row(3), row(2)); + + // query with hybrid search + assertRows(execute("SELECT c FROM %s WHERE v>=0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), 
row(3), row(2), row(1), row(4)); + assertRows(execute("SELECT c FROM %s WHERE v=0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(1)); + assertRows(execute("SELECT c FROM %s WHERE v>0 ORDER BY r ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(4)); + }); + } + + @Test + public void testStaticColumnWithoutRows() throws Throwable + { + createTable("CREATE TABLE %s (k int, c int, s vector static, PRIMARY KEY(k, c))"); + createIndex("CREATE CUSTOM INDEX ON %s(s) USING 'StorageAttachedIndex'"); + + // insert static rows alone, without non-static rows + String insert = "INSERT INTO %s (k, s) VALUES (?, ?)"; + execute(insert, row(1, vector(0.1f, 0.1f))); + execute(insert, row(2, vector(0.1f, 0.2f))); + execute(insert, row(3, vector(0.1f, 0.3f))); + execute(insert, row(4, vector(0.1f, 0.4f))); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 2"), row(1), row(2)); + }); + + // update the best vector to make it the worst + insert = "INSERT INTO %s (k, s) VALUES (?, ?)"; + execute(insert, row(1, vector(0.1f, 0.5f))); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(3), row(4), row(1)); + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 2"), row(2), row(3)); + }); + + // test again with a different order + execute("TRUNCATE TABLE %s"); + execute(insert, row(1, vector(0.1f, 0.4f))); + execute(insert, row(2, vector(0.1f, 0.3f))); + execute(insert, row(3, vector(0.1f, 0.2f))); + execute(insert, row(4, vector(0.1f, 0.1f))); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3), row(2), row(1)); + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 2"), row(4), row(3)); + }); + + // update the best vector to make it the worst + insert 
= "INSERT INTO %s (k, s) VALUES (?, ?)"; + execute(insert, row(4, vector(0.1f, 0.5f))); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(2), row(1), row(4)); + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 2"), row(3), row(2)); + }); + } + + @Test + public void testStaticColumnWithRowsTogether() throws Throwable + { + createTable("CREATE TABLE %s (k int, c int, s vector static, v int, PRIMARY KEY(k, c))"); + createIndex("CREATE CUSTOM INDEX ON %s(s) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + // insert static rows and non-static rows all at once + String insert = "INSERT INTO %s (k, c, s, v) VALUES (?, ?, ?, ?)"; + execute(insert, row(1, 10, vector(0.1f, 0.1f), 0)); + execute(insert, row(2, 20, vector(0.1f, 0.2f), 0)); + execute(insert, row(3, 30, vector(0.1f, 0.3f), 1)); + execute(insert, row(4, 40, vector(0.1f, 0.4f), 1)); + beforeAndAfterFlush(() -> { + // query with ANN only + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 2"), row(1), row(2)); + + // query with hybrid search + assertRows(execute("SELECT k FROM %s WHERE v>=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT k FROM %s WHERE v=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2)); + assertRows(execute("SELECT k FROM %s WHERE v>0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(4)); + }); + + // update the best vector to make it the worst + execute(insert, row(1, 10, vector(0.1f, 0.5f), 0)); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(3), row(4), row(1)); + assertRows(execute("SELECT k FROM %s WHERE v>=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(3), 
row(4), row(1)); + }); + + // update the worst vector to make it the best + execute(insert, row(1, 10, vector(0.1f, 0.1f), 0)); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT k FROM %s WHERE v>=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + }); + + // test again with a different order + execute("TRUNCATE TABLE %s"); + execute(insert, row(1, 10, vector(0.1f, 0.4f), 0)); + execute(insert, row(2, 20, vector(0.1f, 0.3f), 0)); + execute(insert, row(3, 30, vector(0.1f, 0.2f), 1)); + execute(insert, row(4, 40, vector(0.1f, 0.1f), 1)); + beforeAndAfterFlush(() -> { + // query with ANN only + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3), row(2), row(1)); + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 2"), row(4), row(3)); + + // query with hybrid search + assertRows(execute("SELECT k FROM %s WHERE v>=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3), row(2), row(1)); + assertRows(execute("SELECT k FROM %s WHERE v=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(1)); + assertRows(execute("SELECT k FROM %s WHERE v>0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3)); + }); + + // update the best vector to make it the worst + execute(insert, row(4, 40, vector(0.1f, 0.5f), 1)); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(2), row(1), row(4)); + assertRows(execute("SELECT k FROM %s WHERE v>=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(2), row(1), row(4)); + }); + + // update the worst vector to make it the best + execute(insert, row(4, 40, vector(0.1f, 0.1f), 1)); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3), row(2), row(1)); + assertRows(execute("SELECT k FROM %s 
WHERE v>=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(4), row(3), row(2), row(1)); + }); + } + + @Test + public void testStaticColumnWithRowsSeparate() throws Throwable + { + createTable("CREATE TABLE %s (k int, c int, s vector static, v int, PRIMARY KEY(k, c))"); + createIndex("CREATE CUSTOM INDEX ON %s(s) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + // insert static rows alone, without non-static rows + String insert = "INSERT INTO %s (k, s) VALUES (?, ?)"; + execute(insert, row(1, vector(0.1f, 0.1f))); + execute(insert, row(2, vector(0.1f, 0.2f))); + execute(insert, row(3, vector(0.1f, 0.3f))); + execute(insert, row(4, vector(0.1f, 0.4f))); + beforeAndAfterFlush(() -> { + // query with ANN only + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + + // query with hybrid search + assertRows(execute("SELECT k FROM %s WHERE v>=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10")); + assertRows(execute("SELECT k FROM %s WHERE v=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10")); + assertRows(execute("SELECT k FROM %s WHERE v>0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10")); + }); + + // insert non-static rows, overlapping and non-overlapping with static rows + insert = "INSERT INTO %s (k, c, v) VALUES (?, ?, ?)"; + execute(insert, row(1, 10, 0)); + execute(insert, row(1, 11, 0)); + execute(insert, row(2, 20, 0)); + execute(insert, row(2, 21, 0)); + execute(insert, row(3, 30, 1)); + execute(insert, row(4, 40, 1)); + execute(insert, row(5, 50, 1)); // this one does not have a static row + beforeAndAfterFlush(() -> { + // query with ANN only + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + + // query with hybrid search + assertRows(execute("SELECT k FROM %s WHERE v>=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + assertRows(execute("SELECT k FROM %s WHERE v=0 ORDER BY s 
ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2)); + assertRows(execute("SELECT k FROM %s WHERE v>0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(4)); + }); + + // update the best vector to make it the worst + insert = "INSERT INTO %s (k, s) VALUES (?, ?)"; + execute(insert, row(1, vector(0.1f, 0.5f))); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(3), row(4), row(1)); + }); + + // update the worst vector to make it the best + insert = "INSERT INTO %s (k, s) VALUES (?, ?)"; + execute(insert, row(1, vector(0.1f, 0.1f))); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(1), row(2), row(3), row(4)); + }); + } + + @Test + public void testStaticColumnWithDifferentSourcesWithRegulars() throws Throwable + { + createTable("CREATE TABLE %s (k int, c int, s vector static, v int, PRIMARY KEY(k, c))"); + createIndex("CREATE CUSTOM INDEX ON %s(s) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + // insert static rows alone, with non-static rows + String insert = "INSERT INTO %s (k, s, c, v) VALUES (?, ?, ?, ?)"; + execute(insert, row(1, vector(0.1f, 0.4f), 1, 0)); + execute(insert, row(2, vector(0.1f, 0.3f), 1, 1)); + execute(insert, row(3, vector(0.1f, 0.2f), 1, 0)); + execute(insert, row(4, vector(0.1f, 0.1f), 1, 1)); + flush(); + + // update the best vector to make it the worst, with non-static rows + execute(insert, row(4, vector(0.1f, 0.5f), 1, 1)); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(2), row(1), row(4)); + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 2"), row(3), row(2)); + + assertRows(execute("SELECT k FROM %s WHERE v=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(1)); + assertRows(execute("SELECT k FROM %s WHERE v=0 ORDER BY s ANN OF [0.1, 0.1] 
LIMIT 1"), row(3)); + + assertRows(execute("SELECT k FROM %s WHERE v=1 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(4)); + assertRows(execute("SELECT k FROM %s WHERE v=1 ORDER BY s ANN OF [0.1, 0.1] LIMIT 1"), row(2)); + }); + + // test again with a different order + execute("TRUNCATE TABLE %s"); + execute(insert, row(1, vector(0.1f, 0.1f), 1, 0)); + execute(insert, row(2, vector(0.1f, 0.2f), 1, 1)); + execute(insert, row(3, vector(0.1f, 0.3f), 1, 0)); + execute(insert, row(4, vector(0.1f, 0.4f), 1, 1)); + flush(); + + // update the best vector to make it the worst, with non-static rows + execute(insert, row(1, vector(0.1f, 0.5f), 1, 0)); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(3), row(4), row(1)); + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 2"), row(2), row(3)); + + assertRows(execute("SELECT k FROM %s WHERE v=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(1)); + assertRows(execute("SELECT k FROM %s WHERE v=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 1"), row(3)); + + assertRows(execute("SELECT k FROM %s WHERE v=1 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(4)); + assertRows(execute("SELECT k FROM %s WHERE v=1 ORDER BY s ANN OF [0.1, 0.1] LIMIT 1"), row(2)); + }); + } + + @Test + public void testStaticColumnWithDifferentSourcesWithoutRegulars() throws Throwable + { + createTable("CREATE TABLE %s (k int, c int, s vector static, PRIMARY KEY(k, c))"); + createIndex("CREATE CUSTOM INDEX ON %s(s) USING 'StorageAttachedIndex'"); + + // insert static rows alone, without non-static rows + String insert = "INSERT INTO %s (k, s) VALUES (?, ?)"; + execute(insert, row(1, vector(0.1f, 0.4f))); + execute(insert, row(2, vector(0.1f, 0.3f))); + execute(insert, row(3, vector(0.1f, 0.2f))); + execute(insert, row(4, vector(0.1f, 0.1f))); + flush(); + + // update the best vector to make it the worst, without non-static rows + execute(insert, row(4, 
vector(0.1f, 0.5f))); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(2), row(1), row(4)); + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 2"), row(3), row(2)); + }); + + // test again with a different order + execute("TRUNCATE TABLE %s"); + execute(insert, row(1, vector(0.1f, 0.1f))); + execute(insert, row(2, vector(0.1f, 0.2f))); + execute(insert, row(3, vector(0.1f, 0.3f))); + execute(insert, row(4, vector(0.1f, 0.4f))); + flush(); + + // update the best vector to make it the worst, without non-static rows + execute(insert, row(1, vector(0.1f, 0.5f))); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(3), row(4), row(1)); + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 2"), row(2), row(3)); + }); + } + + @Test + public void testStaticColumnWithDifferentSourcesWithAndWithoutRegulars() throws Throwable + { + createTable("CREATE TABLE %s (k int, c int, s vector static, v int, PRIMARY KEY(k, c))"); + createIndex("CREATE CUSTOM INDEX ON %s(s) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + // insert static rows alone, with non-static rows + String insert = "INSERT INTO %s (k, s, c, v) VALUES (?, ?, ?, ?)"; + execute(insert, row(1, vector(0.1f, 0.4f), 1, 0)); + execute(insert, row(2, vector(0.1f, 0.3f), 1, 1)); + execute(insert, row(3, vector(0.1f, 0.2f), 1, 0)); + execute(insert, row(4, vector(0.1f, 0.1f), 1, 1)); + flush(); + + // update the best vector to make it the worst, without non-static rows + insert = "INSERT INTO %s (k, s) VALUES (?, ?)"; + execute(insert, row(4, vector(0.1f, 0.5f))); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(2), row(1), row(4)); + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 
2"), row(3), row(2)); + + assertRows(execute("SELECT k FROM %s WHERE v=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(3), row(1)); + assertRows(execute("SELECT k FROM %s WHERE v=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 1"), row(3)); + + assertRows(execute("SELECT k FROM %s WHERE v=1 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(4)); + assertRows(execute("SELECT k FROM %s WHERE v=1 ORDER BY s ANN OF [0.1, 0.1] LIMIT 1"), row(2)); + }); + + // test again with a different order + execute("TRUNCATE TABLE %s"); + execute(insert, row(1, vector(0.1f, 0.1f))); + execute(insert, row(2, vector(0.1f, 0.2f))); + execute(insert, row(3, vector(0.1f, 0.3f))); + execute(insert, row(4, vector(0.1f, 0.4f))); + flush(); + + // update the best vector to make it the worst, without non-static rows + insert = "INSERT INTO %s (k, s, c, v) VALUES (?, ?, ?, ?)"; + execute(insert, row(1, vector(0.1f, 0.5f), 1, 0)); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(2), row(3), row(4), row(1)); + assertRows(execute("SELECT k FROM %s ORDER BY s ANN OF [0.1, 0.1] LIMIT 2"), row(2), row(3)); + + assertRows(execute("SELECT k FROM %s WHERE v=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10"), row(1)); + assertRows(execute("SELECT k FROM %s WHERE v=0 ORDER BY s ANN OF [0.1, 0.1] LIMIT 1"), row(1)); + + assertRows(execute("SELECT k FROM %s WHERE v=1 ORDER BY s ANN OF [0.1, 0.1] LIMIT 10")); + assertRows(execute("SELECT k FROM %s WHERE v=1 ORDER BY s ANN OF [0.1, 0.1] LIMIT 1")); + }); + } +}