Skip to content

CNDB-14343: Fix ANN queries on primary key and static columns #1800

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ public SSTableId<?> getSourceSstableId()
return sourceSstableId;
}

@Override
public PrimaryKeyWithSource forStaticRow()
{
    // Derive the static-row form of the wrapped key, then re-attach the
    // unchanged sstable provenance (id, row id, min/max key bounds).
    PrimaryKey staticKey = primaryKey().forStaticRow();
    return new PrimaryKeyWithSource(staticKey,
                                    sourceSstableId,
                                    sourceRowId,
                                    sourceSstableMinKey,
                                    sourceSstableMaxKey);
}

@Override
public Token token()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ public PrimaryKey loadDeferred()
return this;
}

@Override
public PartitionAwarePrimaryKey forStaticRow()
{
    // A partition-aware key carries no clustering component, so the key already
    // identifies the partition (and thus its static row); no new instance is needed.
    return this;
}

@Override
public Token token()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ private RowAwarePrimaryKey(Token token, DecoratedKey partitionKey, Clustering cl
this.primaryKeySupplier = primaryKeySupplier;
}

@Override
public RowAwarePrimaryKey forStaticRow()
{
    // Substitute STATIC_CLUSTERING for this key's clustering so the resulting key
    // addresses the partition's static row; token, partition key and the deferred
    // primary-key supplier are carried over unchanged.
    return new RowAwarePrimaryKey(token, partitionKey, Clustering.STATIC_CLUSTERING, primaryKeySupplier);
}

@Override
public Token token()
{
Expand Down Expand Up @@ -220,12 +226,12 @@ public long ramBytesUsed()
// Object header + 4 references (token, partitionKey, clustering, primaryKeySupplier) + implicit outer reference
long size = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER +
5L * RamUsageEstimator.NUM_BYTES_OBJECT_REF;

if (token != null)
size += token.getHeapSize();
if (partitionKey != null)
size += RamUsageEstimator.NUM_BYTES_OBJECT_HEADER +
2L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + // token and key references
size += RamUsageEstimator.NUM_BYTES_OBJECT_HEADER +
2L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + // token and key references
2L * Long.BYTES;
// We don't count clustering size here as it's managed elsewhere
return size;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,15 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(QueryContext cont
var relevantOrdinals = new IntHashSet();

var keysInRange = PrimaryKeyListUtil.getKeysInRange(keys, minimumKey, maximumKey);
boolean isStatic = indexContext.getDefinition().isStatic();

keysInRange.forEach(k ->
{
// if the indexed column is static, we need to get the static row associated with the non-static row that
// might be referenced by the key
if (isStatic)
k = k.forStaticRow();

var v = graph.vectorForKey(k);
if (v == null)
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,12 @@ public UnfilteredRowIterator getPartition(PrimaryKey key, ColumnFamilyStore.View
// Class to transform the row to include its source table.
Function<Object, Transformation<BaseRowIterator<?>>> rowTransformer = (Object sourceTable) -> new Transformation<>()
{
@Override
protected Row applyToStatic(Row row)
{
    // Wrap the static row the same way applyToRow wraps regular rows, so the
    // static row also carries its source table through the transformation.
    return new RowWithSourceTable(row, sourceTable);
}

@Override
protected Row applyToRow(Row row)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -698,18 +698,35 @@ public UnfilteredRowIterator readAndValidatePartition(PrimaryKey pk, List<Primar
{
queryContext.addPartitionsRead(1);
queryContext.checkpoint();
var staticRow = partition.staticRow();
UnfilteredRowIterator clusters = applyIndexFilter(partition, filterTree, queryContext);

if (clusters == null || !clusters.hasNext())
{
processedKeys.add(pk);
if (clusters == null)
return null;
}

var now = FBUtilities.nowInSeconds();
var staticRow = partition.staticRow();
boolean isStaticValid = false;

// Each of the primary keys are equal, but they have different source tables.
// Therefore, we check to see if the static row is valid for any of them.
for (PrimaryKeyWithSortKey sourceKey : sourceKeys)
{
if (sourceKey.isIndexDataValid(staticRow, now))
{
// If there are no regular rows, return the static row only
if (!clusters.hasNext())
return new PrimaryKeyIterator(partition, staticRow, null, sourceKey, syntheticScoreColumn);

isStaticValid = true;
break;
}
}

// If the static row isn't valid, we can skip the partition.
if (!isStaticValid)
return null;

var row = clusters.next();
assert !clusters.hasNext() : "Expected only one row per partition";
if (!row.isRangeTombstoneMarker())
{
for (PrimaryKeyWithSortKey sourceKey : sourceKeys)
Expand Down Expand Up @@ -744,9 +761,15 @@ public void close()
public static class PrimaryKeyIterator extends AbstractUnfilteredRowIterator
{
private boolean consumed = false;

@Nullable
private final Unfiltered row;

public PrimaryKeyIterator(UnfilteredRowIterator partition, Row staticRow, Unfiltered content, PrimaryKeyWithSortKey primaryKeyWithSortKey, ColumnMetadata syntheticScoreColumn)
public PrimaryKeyIterator(UnfilteredRowIterator partition,
Row staticRow,
@Nullable Unfiltered content,
PrimaryKeyWithSortKey primaryKeyWithSortKey,
ColumnMetadata syntheticScoreColumn)
{
super(partition.metadata(),
partition.partitionKey(),
Expand All @@ -756,7 +779,7 @@ public PrimaryKeyIterator(UnfilteredRowIterator partition, Row staticRow, Unfilt
partition.isReverseOrder(),
partition.stats());

if (!content.isRow() || !(primaryKeyWithSortKey instanceof PrimaryKeyWithScore))
if (content == null || !content.isRow() || !(primaryKeyWithSortKey instanceof PrimaryKeyWithScore))
{
this.row = content;
return;
Expand Down Expand Up @@ -791,7 +814,7 @@ public PrimaryKeyIterator(UnfilteredRowIterator partition, Row staticRow, Unfilt
@Override
protected Unfiltered computeNext()
{
if (consumed)
if (consumed || row == null)
return endOfData();
consumed = true;
return row;
Expand Down
40 changes: 23 additions & 17 deletions src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import org.apache.cassandra.cql3.Operator;
import org.apache.cassandra.cql3.Ordering;
import org.apache.cassandra.db.Clustering;
import org.apache.cassandra.db.ColumnFamilyStore;
import org.apache.cassandra.db.DecoratedKey;
import org.apache.cassandra.db.Keyspace;
Expand All @@ -46,6 +47,7 @@
import org.apache.cassandra.db.marshal.FloatType;
import org.apache.cassandra.db.partitions.PartitionIterator;
import org.apache.cassandra.db.partitions.UnfilteredPartitionIterator;
import org.apache.cassandra.db.rows.BTreeRow;
import org.apache.cassandra.db.rows.BaseRowIterator;
import org.apache.cassandra.db.rows.Cell;
import org.apache.cassandra.db.rows.Row;
Expand All @@ -67,16 +69,19 @@

/**
* Processor applied to SAI based ORDER BY queries.
*
* * On a replica:
* * - filter(ScoreOrderedResultRetriever) is used to collect up to the top-K rows.
* * - We store any tombstones as well, to avoid losing them during coordinator reconciliation.
* * - The result is returned in PK order so that coordinator can merge from multiple replicas.
*
* </p>
* On a replica:
* <ul>
* <li>filter(ScoreOrderedResultRetriever) is used to collect up to the top-K rows.</li>
* <li>We store any tombstones as well, to avoid losing them during coordinator reconciliation.</li>
* <li>The result is returned in PK order so that coordinator can merge from multiple replicas.</li>
* </ul>
* On a coordinator:
* - reorder(PartitionIterator) is used to consume all rows from the provided partitions,
* compute the order based on either a column ordering or a similarity score, and keep top-K.
* - The result is returned in score/sortkey order.
* <ul>
* <li>reorder(PartitionIterator) is used to consume all rows from the provided partitions,
* compute the order based on either a column ordering or a similarity score, and keep top-K.</li>
* <li>The result is returned in score/sortkey order.</li>
* </ul>
*/
public class TopKProcessor
{
Expand Down Expand Up @@ -104,7 +109,7 @@ public TopKProcessor(ReadCommand command)

this.indexContext = indexAndExpression.left;
this.expression = indexAndExpression.right;
if (expression.operator() == Operator.ANN && !Ordering.Ann.useSyntheticScore())
if (expression.operator() == Operator.ANN && !(Ordering.Ann.useSyntheticScore() && expression.column().isRegular()))
this.queryVector = vts.createFloatVector(TypeUtil.decomposeVector(indexContext, expression.getIndexValue().duplicate()));
else
this.queryVector = null;
Expand Down Expand Up @@ -253,6 +258,9 @@ private PartitionResults processScoredPartition(BaseRowIterator<?> partitionRowI
float keyAndStaticScore = getScoreForRow(key, staticRow);
var pr = new PartitionResults(partitionInfo);

if (!partitionRowIterator.hasNext())
pr.addRow(Triple.of(partitionInfo, BTreeRow.emptyRow(Clustering.EMPTY), keyAndStaticScore));

while (partitionRowIterator.hasNext())
{
Unfiltered unfiltered = partitionRowIterator.next();
Expand All @@ -278,31 +286,29 @@ private PartitionResults processScoredPartition(BaseRowIterator<?> partitionRowI
/**
 * Adds the single row (or tombstone marker) of the given partition iterator to the
 * per-partition map. The partition is recorded even when it yields no row, so that
 * partition-level information (e.g. a static-row-only match) is not lost.
 *
 * @param unfilteredByPartition accumulator keyed by partition, values ordered by clustering
 * @param partitionRowIterator iterator expected to produce at most one unfiltered
 * @return 1 when a counted row was added, 0 for a range tombstone marker or an empty partition's null entry
 */
private int processSingleRowPartition(TreeMap<PartitionInfo, TreeSet<Unfiltered>> unfilteredByPartition,
                                      BaseRowIterator<?> partitionRowIterator)
{
    // An empty iterator is represented by a null unfiltered; addUnfiltered skips nulls
    // but still registers the partition in the map.
    Unfiltered unfiltered = partitionRowIterator.hasNext() ? partitionRowIterator.next() : null;
    assert !partitionRowIterator.hasNext() : "Only one row should be returned";
    // Always include tombstones for coordinator. It relies on ReadCommand#withMetricsRecording to throw
    // TombstoneOverwhelmingException to prevent OOM.
    PartitionInfo partitionInfo = PartitionInfo.create(partitionRowIterator);
    addUnfiltered(unfilteredByPartition, partitionInfo, unfiltered);
    return unfiltered != null && unfiltered.isRangeTombstoneMarker() ? 0 : 1;
}

/**
 * Registers the partition in the accumulator and, when present, adds the unfiltered
 * to that partition's clustering-ordered set.
 *
 * @param unfilteredByPartition accumulator keyed by partition
 * @param partitionInfo the partition the unfiltered belongs to
 * @param unfiltered the row or marker to record; null means the partition produced no row
 */
private void addUnfiltered(SortedMap<PartitionInfo, TreeSet<Unfiltered>> unfilteredByPartition,
                           PartitionInfo partitionInfo,
                           Unfiltered unfiltered)
{
    // computeIfAbsent ensures the partition is recorded even for a null (empty) entry.
    var map = unfilteredByPartition.computeIfAbsent(partitionInfo, k -> new TreeSet<>(command.metadata().comparator));
    if (unfiltered != null)
        map.add(unfiltered);
}

private float getScoreForRow(DecoratedKey key, Row row)
{
ColumnMetadata column = indexContext.getDefinition();

if (column.isPrimaryKeyColumn() && key == null)
if (column.isPartitionKey() && key == null)
return 0;

if (column.isStatic() && !row.isStatic())
Expand Down
7 changes: 7 additions & 0 deletions src/java/org/apache/cassandra/index/sai/utils/PrimaryKey.java
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ static Factory factory(ClusteringComparator clusteringComparator, IndexFeatureSe
: new PartitionAwarePrimaryKeyFactory();
}

/**
 * Returns a {@link PrimaryKey} to fetch the static row of the partition associated with this primary key.
 * Row-aware implementations substitute the static clustering; partition-only implementations
 * may simply return {@code this}, as the key already identifies the whole partition.
 *
 * @return a {@link PrimaryKey} for the static row
 */
PrimaryKey forStaticRow();

/**
* Returns the {@link Token} associated with this primary key.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,26 @@ public class PrimaryKeyWithByteComparable extends PrimaryKeyWithSortKey

/**
 * Creates a key sourced from a memtable.
 */
public PrimaryKeyWithByteComparable(IndexContext context, Memtable sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable)
{
    // Delegate to the common private constructor; the source is stored as an opaque Object.
    this(context, (Object) sourceTable, primaryKey, byteComparable);
}

/**
 * Creates a key sourced from an sstable. Parameterized SSTableId<?> avoids the raw type.
 */
public PrimaryKeyWithByteComparable(IndexContext context, SSTableId<?> sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable)
{
    this(context, (Object) sourceTable, primaryKey, byteComparable);
}

/**
 * Common constructor taking the source as an Object, shared by the memtable and
 * sstable variants (and by {@code forStaticRow}, which copies an existing source).
 */
private PrimaryKeyWithByteComparable(IndexContext context, Object sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable)
{
    super(context, sourceTable, primaryKey);
    this.byteComparable = byteComparable;
}

@Override
public PrimaryKeyWithByteComparable forStaticRow()
{
    // Re-wrap the static-row form of the underlying key, keeping the same
    // source table and sort key.
    PrimaryKey staticKey = primaryKey.forStaticRow();
    return new PrimaryKeyWithByteComparable(context, sourceTable, staticKey, byteComparable);
}

@Override
protected boolean isIndexDataEqualToLiveData(ByteBuffer value)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,26 @@ public class PrimaryKeyWithScore extends PrimaryKeyWithSortKey

/**
 * Creates a scored key sourced from a memtable.
 */
public PrimaryKeyWithScore(IndexContext context, Memtable source, PrimaryKey primaryKey, float indexScore)
{
    // Delegate to the common private constructor; the source is stored as an opaque Object.
    this(context, (Object) source, primaryKey, indexScore);
}

/**
 * Creates a scored key sourced from an sstable. Parameterized SSTableId<?> avoids the raw type.
 */
public PrimaryKeyWithScore(IndexContext context, SSTableId<?> source, PrimaryKey primaryKey, float indexScore)
{
    this(context, (Object) source, primaryKey, indexScore);
}

/**
 * Common constructor taking the source as an Object, shared by the memtable and
 * sstable variants (and by {@code forStaticRow}, which copies an existing source).
 */
private PrimaryKeyWithScore(IndexContext context, Object source, PrimaryKey primaryKey, float indexScore)
{
    super(context, source, primaryKey);
    this.indexScore = indexScore;
}

@Override
public PrimaryKeyWithScore forStaticRow()
{
    // Carry the score and source table over to the static-row form of the key.
    PrimaryKey staticKey = primaryKey.forStaticRow();
    return new PrimaryKeyWithScore(context, sourceTable, staticKey, indexScore);
}

@Override
protected boolean isIndexDataEqualToLiveData(ByteBuffer value)
{
Expand Down
Loading