CNDB-13952: Handle Chronicle Map entry overflow in vector index compaction #1731


Open · wants to merge 5 commits into base: main

@@ -129,9 +129,8 @@ public static RemappedPostings describeForCompaction(Structure structure, int gr
int maxOldOrdinal = Integer.MIN_VALUE;
int maxRow = Integer.MIN_VALUE;
var extraOrdinals = new Int2IntHashMap(Integer.MIN_VALUE);
for (var entry : postingsMap.entrySet())
for (var postings : postingsMap.values())
{
var postings = entry.getValue();
int ordinal = postings.getOrdinal();

maxOldOrdinal = Math.max(maxOldOrdinal, ordinal);
@@ -470,8 +469,20 @@ private static RemappedPostings createGenericRenumberedMapping(Set<Integer> live
*/
public static <T> RemappedPostings createGenericIdentityMapping(Map<VectorFloat<?>, ? extends VectorPostings<T>> postingsMap)
{
var maxOldOrdinal = postingsMap.values().stream().mapToInt(VectorPostings::getOrdinal).max().orElseThrow();
int maxRow = postingsMap.values().stream().flatMap(p -> p.getRowIds().stream()).mapToInt(i -> i).max().orElseThrow();
// It can be expensive to iterate over the postings map. We do it once to get the max ordinal and then
// again to build the bitset.
int maxOldOrdinal = Integer.MIN_VALUE;
int maxRow = Integer.MIN_VALUE;
for (var postings : postingsMap.values())
{
maxOldOrdinal = max(maxOldOrdinal, postings.getOrdinal());
for (int rowId : postings.getRowIds())
maxRow = max(maxRow, rowId);
}

if (maxOldOrdinal < 0 || maxRow < 0)
throw new IllegalStateException("maxOldOrdinal or maxRow is negative: " + maxOldOrdinal + ' ' + maxRow);

var presentOrdinals = new FixedBitSet(maxOldOrdinal + 1);
for (var entry : postingsMap.entrySet())
presentOrdinals.set(entry.getValue().getOrdinal());

@@ -127,6 +127,10 @@ public class CompactionGraph implements Closeable, Accountable
// if `useSyntheticOrdinals` is true then we use `nextOrdinal` to avoid holes, otherwise use rowId as source of ordinals
private final boolean useSyntheticOrdinals;
private int nextOrdinal = 0;
private int postingsCount = 0;

// Used to force flush on next add
private boolean requiresFlush = false;

// protects the fine-tuning changes (done in maybeAddVector) from addGraphNode threads
// (and creates happens-before events so we don't need to mark the other fields volatile)
@@ -245,7 +249,8 @@ public int size()

public boolean isEmpty()
{
return postingsMap.values().stream().allMatch(VectorPostings::isEmpty);
// This relies on the fact that compaction never has vectors pointing to empty postings lists.
return postingsMap.isEmpty();
}

/**
@@ -270,6 +275,9 @@ public InsertionResult maybeAddVector(ByteBuffer term, int segmentRowId) throws
return new InsertionResult(0);
}

// At this point, we'll add the posting, so it's safe to count it.
postingsCount++;

// if we don't see sequential rowids, it means the skipped row(s) have null vectors
if (segmentRowId != lastRowId + 1)
postingsStructure = Structure.ZERO_OR_ONE_TO_MANY;
@@ -347,7 +355,7 @@ public InsertionResult maybeAddVector(ByteBuffer term, int segmentRowId) throws
var newPosting = postings.add(segmentRowId);
assert newPosting;
bytesUsed += postings.bytesPerPosting();
postingsMap.put(vector, postings); // re-serialize to disk
requiresFlush = safePut(postingsMap, vector, postings); // re-serialize to disk

return new InsertionResult(bytesUsed);
}
@@ -381,8 +389,7 @@ public SegmentMetadata.ComponentMetadataMap flush() throws IOException
postingsMap.keySet().size(), builder.getGraph().size());
if (logger.isDebugEnabled())
{
logger.debug("Writing graph with {} rows and {} distinct vectors",
postingsMap.values().stream().mapToInt(VectorPostings::size).sum(), builder.getGraph().size());
logger.debug("Writing graph with {} rows and {} distinct vectors", postingsCount, builder.getGraph().size());
logger.debug("Estimated size is {} + {}", compressedVectors.ramBytesUsed(), builder.getGraph().ramBytesUsed());
}

@@ -476,7 +483,27 @@ public long ramBytesUsed()

public boolean requiresFlush()
{
return builder.getGraph().size() >= postingsEntriesAllocated;
return builder.getGraph().size() >= postingsEntriesAllocated || requiresFlush;
}

static <T> boolean safePut(ChronicleMap<T, CompactionVectorPostings> map, T key, CompactionVectorPostings value)
{
try
{
map.put(key, value);
return false;
}
catch (IllegalArgumentException e)
{
logger.debug("Error serializing postings to disk, will reattempt with compression. Vector {} had {} postings",
key, value.size(), e);
// This is an extreme edge case where there are many duplicate vectors. This naive approach
// means that we might have a smaller vector graph than desired, but at least we will not
// fail to build the index.
value.setShouldCompress(true);
map.put(key, value);

Is it possible it still fails after compression? If so, then what? Maybe we should still provide at least some diagnostic message?

Member Author

Given my testing and our usage, it's unlikely to fail a second time. We add postings to this map one row at a time on a single thread, so when we cross the threshold for max postings in an array, we do so by 4 bytes. I'll update the debug log line above since that'll likely be sufficient. The other option is to see if there is a better data structure for us. I heard there was a Lucene option that might handle this more gracefully without special encoding.
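
To make the arithmetic above concrete, here is a rough, hypothetical sketch (not part of this PR) comparing fixed 4-byte postings against Lucene-style vint sizes. A vint spends 7 payload bits per byte, so any row id below 2^28 takes at most 4 bytes and smaller ids take fewer, which is why re-serializing the list with writeVInt comfortably recovers the 4 bytes by which a single put can overflow the entry. The class and method names below are made up for the illustration.

// Hypothetical illustration only; not part of the PR.
public final class VIntSavingsSketch
{
    // A Lucene-style vint stores 7 payload bits per byte; the high bit means "more bytes follow".
    static int vIntSize(int value)
    {
        int bytes = 1;
        while ((value & ~0x7F) != 0)
        {
            value >>>= 7;
            bytes++;
        }
        return bytes;
    }

    public static void main(String[] args)
    {
        long fixedBytes = 0;
        long vintBytes = 0;
        for (int rowId = 0; rowId < 1_000_000; rowId++)
        {
            fixedBytes += Integer.BYTES;   // cost of a plain writeInt per posting
            vintBytes += vIntSize(rowId);  // cost of a writeVInt per posting
        }
        // Saving 4 or more bytes is enough, because a single put overflows the entry by only one posting.
        System.out.printf("fixed=%d, vint=%d, saved=%d bytes%n", fixedBytes, vintBytes, fixedBytes - vintBytes);
    }
}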

return true;
}
}

private static class VectorFloatMarshaller implements BytesReader<VectorFloat<?>>, BytesWriter<VectorFloat<?>> {

122 changes: 109 additions & 13 deletions src/java/org/apache/cassandra/index/sai/disk/vector/VectorPostings.java
@@ -18,9 +18,9 @@

package org.apache.cassandra.index.sai.disk.vector;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Function;
import java.util.function.ToIntFunction;

import com.google.common.base.Preconditions;
@@ -30,6 +30,8 @@
import net.openhft.chronicle.hash.serialization.BytesReader;
import net.openhft.chronicle.hash.serialization.BytesWriter;
import org.agrona.collections.IntArrayList;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.DataOutput;

public class VectorPostings<T>
{
@@ -160,7 +162,10 @@ public int getOrdinal(boolean assertSet)
return ordinal;
}

public static class CompactionVectorPostings extends VectorPostings<Integer> {
public static class CompactionVectorPostings extends VectorPostings<Integer>
{
private volatile boolean shouldCompress = false;

public CompactionVectorPostings(int ordinal, List<Integer> raw)
{
super(raw);
@@ -179,6 +184,11 @@ public void setOrdinal(int ordinal)
throw new UnsupportedOperationException();
}

public void setShouldCompress(boolean shouldCompress)
{
this.shouldCompress = shouldCompress;
}

@Override
public IntArrayList getRowIds()
{
@@ -193,7 +203,7 @@ public IntArrayList getRowIds()
public long bytesPerPosting()
{
long REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF;
return REF_BYTES + Integer.BYTES;
return REF_BYTES + Integer.BYTES + 1; // 1 byte for boolean
}
}

@@ -203,8 +213,25 @@ static class Marshaller implements BytesReader<CompactionVectorPostings>, BytesW
public void write(Bytes out, CompactionVectorPostings postings) {
out.writeInt(postings.ordinal);
out.writeInt(postings.size());
for (Integer posting : postings.getPostings()) {
out.writeInt(posting);
out.writeBoolean(postings.shouldCompress);

One more thing: doesn't it break backwards compatibility? The format is different now. Shouldn't we bump up the version?

Member Author

This file is a temporary file that does not survive to the next instantiation of the JVM on the same node. See:

// the extension here is important to signal to CFS.scrubDataDirectories that it should be removed if present at restart
Component tmpComponent = new Component(Component.Type.CUSTOM, "chronicle" + Descriptor.TMP_EXT);
postingsFile = dd.fileFor(tmpComponent);
postingsMap = ChronicleMapBuilder.of((Class<VectorFloat<?>>) (Class) VectorFloat.class, (Class<CompactionVectorPostings>) (Class) CompactionVectorPostings.class)
.averageKeySize(dimension * Float.BYTES)
.averageValueSize(VectorPostings.emptyBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF + 2 * Integer.BYTES)
.keyMarshaller(new VectorFloatMarshaller())
.valueMarshaller(new VectorPostings.Marshaller())
.entries(postingsEntriesAllocated)
.createPersistedTo(postingsFile.toJavaIOFile());

We use the Descriptor.TMP_EXT file extension to ensure that the file is removed if present at restart.

if (postings.shouldCompress)
{
try
{
BytesDataOutput writer = new BytesDataOutput(out);
for (int posting : postings.getPostings())
writer.writeVInt(posting);
Comment on lines +222 to +223

Are those postings sorted? Just an idea: if they are sorted, maybe it is better to use delta encoding (and if they are not sorted, maybe we could sort them?). The deltas would usually be smaller, hence the vints would be smaller as well. Especially if there are a lot of duplicates, you'd get many zeroes, which compress down to 1 byte. It looks like you don't need random access into the middle of a posting list on disk; you deserialize it all at once.

Member Author

They are sorted, yes. We could consider delta encoding too. There are not expected to be duplicates though, because a row has at most one vector when constructing a graph during compaction. And you're correct that we consume the list entirely.

Member Author

My main reason for not going further with the compression is that we flush immediately after hitting this code block, and we only need to find 4 bytes of savings to prevent a subsequent failure on put.
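
For reference, a minimal sketch of the delta-encoding idea discussed in this thread (hypothetical, not part of this change). It assumes the row ids are sorted in ascending order, as confirmed above; the DeltaEncodedPostingsSketch class and the writeDeltas/readDeltas helpers are illustrative names only.

import java.io.IOException;

import org.agrona.collections.IntArrayList;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.DataOutput;

final class DeltaEncodedPostingsSketch
{
    // Write gaps between consecutive sorted row ids; small gaps keep most vints at a single byte.
    static void writeDeltas(DataOutput out, IntArrayList sortedRowIds) throws IOException
    {
        int previous = 0;
        for (int i = 0; i < sortedRowIds.size(); i++)
        {
            int rowId = sortedRowIds.getInt(i);
            out.writeVInt(rowId - previous);
            previous = rowId;
        }
    }

    // Reverse the transform: accumulate the gaps back into absolute row ids.
    static IntArrayList readDeltas(DataInput in, int size) throws IOException
    {
        var rowIds = new IntArrayList(size, -1);
        int previous = 0;
        for (int i = 0; i < size; i++)
        {
            previous += in.readVInt();
            rowIds.addInt(previous);
        }
        return rowIds;
    }
}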

}
catch (IOException e)
{
// Not reachable, because the Bytes out object does not throw an exception on write.
throw new RuntimeException(e);
}
}
else
{
for (Integer posting : postings.getPostings())
out.writeInt(posting);
}
}

@@ -213,20 +240,89 @@ public CompactionVectorPostings read(Bytes in, CompactionVectorPostings using) {
int ordinal = in.readInt();
int size = in.readInt();
assert size >= 0 : size;
boolean isCompressed = in.readBoolean();
if (isCompressed)
{
    try
    {
        BytesDataInput reader = new BytesDataInput(in);
        var postingsList = new IntArrayList(size, -1);
        for (int i = 0; i < size; i++)
            postingsList.add(reader.readVInt());

        return new CompactionVectorPostings(ordinal, postingsList);
    }
    catch (IOException e)
    {
        // Not reachable, because the Bytes in object does not throw an exception on read.
        throw new RuntimeException(e);
    }
}
else
{
    if (size == 1)
    {
        return new CompactionVectorPostings(ordinal, in.readInt());
    }
    else
    {
        var postingsList = new IntArrayList(size, -1);
        for (int i = 0; i < size; i++)
            postingsList.add(in.readInt());

        return new CompactionVectorPostings(ordinal, postingsList);
    }
}
}

private static class BytesDataOutput extends DataOutput
{
private final Bytes<?> bytes;

public BytesDataOutput(Bytes<?> bytes)
{
this.bytes = bytes;
}

@Override
public void writeByte(byte b)
{
bytes.writeByte(b);
}

@Override
public void writeBytes(byte[] b, int off, int len)
{
bytes.write(b, off, len);
}
}

private static class BytesDataInput extends DataInput
{
private final Bytes<?> bytes;

public BytesDataInput(Bytes<?> bytes)
{
this.bytes = bytes;
}

@Override
public byte readByte()
{
return bytes.readByte();
}

@Override
public void readBytes(byte[] b, int off, int len)
{
bytes.read(b, off, len);
}

@Override
public void skipBytes(long l)
{
bytes.readSkip(l);
}
}
}
}
@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.cassandra.index.sai.disk.vector;

import java.util.List;

import org.junit.Test;

import io.github.jbellis.jvector.util.RamUsageEstimator;
import net.openhft.chronicle.map.ChronicleMapBuilder;
import org.apache.cassandra.io.util.File;
import org.apache.cassandra.io.util.FileUtils;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

public class CompactionGraphTest
{

@Test
public void test10Entries() throws Exception
{
// 10 entries means we hit the limit sooner
testEntries(10, 1000, 1);
}

@Test
public void test1MEntries() throws Exception
{
// more entries means it takes longer, but we hit this bug in prod before, so it is worth testing
testEntries(1000000, 5000, 100);
}

@Test
public void test50MEntries() throws Exception
{
// more entries means it takes longer, but we hit this bug in prod before, so it is worth testing
testEntries(50000000, 5000, 100);
}

// Callers of this method are expected to provide enough iterations and postings added per iteration
// to hit the entry size limit without exceeding it too much. Note that we add postings one at a time in the
// compaction graph, so we only ever increment by 4 bytes each time we attempt to re-serialize the entry.
private void testEntries(int entries, int iterations, int postingsAddedPerIteration) throws Exception
{
File postingsFile = FileUtils.createTempFile("testfile", "tmp");
try(var postingsMap = ChronicleMapBuilder.of(Integer.class, (Class<VectorPostings.CompactionVectorPostings>) (Class) VectorPostings.CompactionVectorPostings.class)
.averageValueSize(VectorPostings.emptyBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF + 2 * Integer.BYTES)
.valueMarshaller(new VectorPostings.Marshaller())
.entries(entries)
.createPersistedTo(postingsFile.toJavaIOFile()))
{
postingsMap.put(0, new VectorPostings.CompactionVectorPostings(0, List.of()));
int rowId = 0;
boolean recoveredFromFailure = false;
for (int i = 0; i < iterations; i++)
{
// Iterate so we can fail sooner
var existing = postingsMap.get(0);
for (int j = 0; j < postingsAddedPerIteration; j++)
existing.add(rowId++);

recoveredFromFailure = CompactionGraph.safePut(postingsMap, 0, existing);
if (recoveredFromFailure)
break;
}

assertTrue("Failed to hit entry size limit", recoveredFromFailure);
// Validate that we can read (deserialize) the entry
var existing = postingsMap.get(0);
assertNotNull(existing);
assertEquals(rowId, existing.size());
// Validate that the row ids are correct
var rowIds = existing.getRowIds();
for (int i = 0; i < rowId; i++)
assertEquals(i, (int) rowIds.get(i));
}
}
}