CNDB-13952: Handle Chronicle Map entry overflow in vector index compaction #1731

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open — wants to merge 5 commits into base: main

Changes from 2 commits
@@ -129,9 +129,8 @@ public static RemappedPostings describeForCompaction(Structure structure, int gr
int maxOldOrdinal = Integer.MIN_VALUE;
int maxRow = Integer.MIN_VALUE;
var extraOrdinals = new Int2IntHashMap(Integer.MIN_VALUE);
for (var entry : postingsMap.entrySet())
for (var postings : postingsMap.values())
{
var postings = entry.getValue();
int ordinal = postings.getOrdinal();

maxOldOrdinal = Math.max(maxOldOrdinal, ordinal);
@@ -470,8 +469,20 @@ private static RemappedPostings createGenericRenumberedMapping(Set<Integer> live
*/
public static <T> RemappedPostings createGenericIdentityMapping(Map<VectorFloat<?>, ? extends VectorPostings<T>> postingsMap)
{
var maxOldOrdinal = postingsMap.values().stream().mapToInt(VectorPostings::getOrdinal).max().orElseThrow();
int maxRow = postingsMap.values().stream().flatMap(p -> p.getRowIds().stream()).mapToInt(i -> i).max().orElseThrow();
// It can be expensive to iterate over the postings map. We do it once to get the max ordinal and then
// again to build the bitset.
int maxOldOrdinal = Integer.MIN_VALUE;
int maxRow = Integer.MIN_VALUE;
for (var postings : postingsMap.values())
{
maxOldOrdinal = max(maxOldOrdinal, postings.getOrdinal());
for (int rowId : postings.getRowIds())
maxRow = max(maxRow, rowId);
}

assert maxOldOrdinal >= 0;


I think that the previous behavior with "orElseThrow" was to throw if the collection was empty

Member Author

this now throws an IllegalStateException if the values were not set.
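For illustration, a minimal sketch (hypothetical, not part of this patch) of making the empty-map case an explicit check rather than relying on the asserts below:

if (postingsMap.isEmpty())
    throw new IllegalStateException("postingsMap must not be empty when building an identity mapping");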

assert maxRow >= 0;

var presentOrdinals = new FixedBitSet(maxOldOrdinal + 1);
for (var entry : postingsMap.entrySet())
presentOrdinals.set(entry.getValue().getOrdinal());
@@ -127,6 +127,10 @@ public class CompactionGraph implements Closeable, Accountable
// if `useSyntheticOrdinals` is true then we use `nextOrdinal` to avoid holes, otherwise use rowId as source of ordinals
private final boolean useSyntheticOrdinals;
private int nextOrdinal = 0;
private int postingsCount = 0;

// Used to force flush on next add
private boolean requiresFlush = false;

// protects the fine-tuning changes (done in maybeAddVector) from addGraphNode threads
// (and creates happens-before events so we don't need to mark the other fields volatile)
@@ -245,7 +249,8 @@ public int size()

public boolean isEmpty()
{
return postingsMap.values().stream().allMatch(VectorPostings::isEmpty);
// This relies on the fact that compaction never has vectors pointing to empty postings lists.
return postingsMap.isEmpty();
}

/**
@@ -270,6 +275,9 @@ public InsertionResult maybeAddVector(ByteBuffer term, int segmentRowId) throws
return new InsertionResult(0);
}

// At this point, we'll add the posting, so it's safe to count it.
postingsCount++;

// if we don't see sequential rowids, it means the skipped row(s) have null vectors
if (segmentRowId != lastRowId + 1)
postingsStructure = Structure.ZERO_OR_ONE_TO_MANY;
@@ -347,7 +355,7 @@ public InsertionResult maybeAddVector(ByteBuffer term, int segmentRowId) throws
var newPosting = postings.add(segmentRowId);
assert newPosting;
bytesUsed += postings.bytesPerPosting();
postingsMap.put(vector, postings); // re-serialize to disk
requiresFlush = safePut(postingsMap, vector, postings); // re-serialize to disk

return new InsertionResult(bytesUsed);
}
@@ -381,8 +389,7 @@ public SegmentMetadata.ComponentMetadataMap flush() throws IOException
postingsMap.keySet().size(), builder.getGraph().size());
if (logger.isDebugEnabled())
{
logger.debug("Writing graph with {} rows and {} distinct vectors",
postingsMap.values().stream().mapToInt(VectorPostings::size).sum(), builder.getGraph().size());
logger.debug("Writing graph with {} rows and {} distinct vectors", postingsCount, builder.getGraph().size());
logger.debug("Estimated size is {} + {}", compressedVectors.ramBytesUsed(), builder.getGraph().ramBytesUsed());
}

@@ -476,7 +483,26 @@ public long ramBytesUsed()

public boolean requiresFlush()
{
return builder.getGraph().size() >= postingsEntriesAllocated;
return builder.getGraph().size() >= postingsEntriesAllocated || requiresFlush;
}

static <T> boolean safePut(ChronicleMap<T, CompactionVectorPostings> map, T key, CompactionVectorPostings value)
{
try
{
map.put(key, value);
return false;
}
catch (IllegalArgumentException e)
{
logger.error("Error serializing postings to disk, will reattempt with compression", e);
// This is an extreme edge case where there are many duplicate vectors. This naive approach
// means that we might have a smaller vector graph than desired, but at least we will not
// fail to build the index.
value.setShouldCompress(true);
map.put(key, value);

Is it possible it still fails after compression? If so, then what? Maybe we should still provide at least some diagnostic message?

Member Author
Given my testing and our usage, it's unlikely to fail a second time. We add postings to this map one row at a time on a single thread, so when we cross the threshold for max postings in an array, we do so by 4 bytes. I'll update the debug log line above since that'll likely be sufficient. The other option is to see if there is a better data structure for us. I heard there was a Lucene option that might handle this more gracefully without special encoding.
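For reference, a sketch of the defensive variant discussed here (assumed shape, not the actual patch): if the put still fails after switching to compression, surface a diagnostic naming the offending entry instead of letting the raw IllegalArgumentException propagate.

static <T> boolean safePut(ChronicleMap<T, CompactionVectorPostings> map, T key, CompactionVectorPostings value)
{
    try
    {
        map.put(key, value);
        return false;
    }
    catch (IllegalArgumentException e)
    {
        logger.error("Error serializing postings to disk, will reattempt with compression", e);
        value.setShouldCompress(true);
        try
        {
            map.put(key, value);
        }
        catch (IllegalArgumentException e2)
        {
            // Still too large even with vint-compressed postings; fail with enough context
            // to identify which entry overflowed the Chronicle Map value size.
            throw new IllegalStateException("Postings for ordinal " + value.getOrdinal() + " with " + value.size() +
                                            " row ids exceed the Chronicle Map entry size even after compression", e2);
        }
        return true;
    }
}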

return true;
}
}

private static class VectorFloatMarshaller implements BytesReader<VectorFloat<?>>, BytesWriter<VectorFloat<?>> {
120 changes: 107 additions & 13 deletions src/java/org/apache/cassandra/index/sai/disk/vector/VectorPostings.java
@@ -18,9 +18,9 @@

package org.apache.cassandra.index.sai.disk.vector;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Function;
import java.util.function.ToIntFunction;

import com.google.common.base.Preconditions;
@@ -30,6 +30,8 @@
import net.openhft.chronicle.hash.serialization.BytesReader;
import net.openhft.chronicle.hash.serialization.BytesWriter;
import org.agrona.collections.IntArrayList;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.DataOutput;

public class VectorPostings<T>
{
@@ -160,7 +162,10 @@ public int getOrdinal(boolean assertSet)
return ordinal;
}

public static class CompactionVectorPostings extends VectorPostings<Integer> {
public static class CompactionVectorPostings extends VectorPostings<Integer>
{
private volatile boolean shouldCompress = false;

public CompactionVectorPostings(int ordinal, List<Integer> raw)
{
super(raw);
@@ -179,6 +184,11 @@ public void setOrdinal(int ordinal)
throw new UnsupportedOperationException();
}

public void setShouldCompress(boolean shouldCompress)
{
this.shouldCompress = shouldCompress;
}

@Override
public IntArrayList getRowIds()
{
@@ -193,7 +203,7 @@ public IntArrayList getRowIds()
public long bytesPerPosting()
{
long REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF;
return REF_BYTES + Integer.BYTES;
return REF_BYTES + Integer.BYTES + 1; // 1 byte for boolean
}
}

@@ -203,8 +213,24 @@ static class Marshaller implements BytesReader<CompactionVectorPostings>, BytesW
public void write(Bytes out, CompactionVectorPostings postings) {
out.writeInt(postings.ordinal);
out.writeInt(postings.size());
for (Integer posting : postings.getPostings()) {
out.writeInt(posting);
out.writeBoolean(postings.shouldCompress);

One more thing: doesn't it break backwards compatibility? The format is different now. Shouldn't we bump up the version?

Member Author

This file is a temporary file that does not survive to the next instantiation of the JVM on the same node. See:

// the extension here is important to signal to CFS.scrubDataDirectories that it should be removed if present at restart
Component tmpComponent = new Component(Component.Type.CUSTOM, "chronicle" + Descriptor.TMP_EXT);
postingsFile = dd.fileFor(tmpComponent);
postingsMap = ChronicleMapBuilder.of((Class<VectorFloat<?>>) (Class) VectorFloat.class, (Class<CompactionVectorPostings>) (Class) CompactionVectorPostings.class)
.averageKeySize(dimension * Float.BYTES)
.averageValueSize(VectorPostings.emptyBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF + 2 * Integer.BYTES)
.keyMarshaller(new VectorFloatMarshaller())
.valueMarshaller(new VectorPostings.Marshaller())
.entries(postingsEntriesAllocated)
.createPersistedTo(postingsFile.toJavaIOFile());

We use the Descriptor.TMP_EXT file extension to ensure that the file is removed if present at restart.

if (postings.shouldCompress)
{
try
{
BytesDataOutput writer = new BytesDataOutput(out);
for (int posting : postings.getPostings())
writer.writeVInt(posting);
Comment on lines +222 to +223

Are those postings sorted? Just an idea: if they are sorted, maybe better to use delta-encoding (and if they are not sorted, maybe we could sort them?). Deltas would usually be smaller, hence vints would be smaller as well. Especially if there are a lot of duplicates, you'd get many zeroes which compress down to 1 byte. Looks like you don't need random access to the middle of a posting list on disk, but deserialize it all at once.

Member Author

They are sorted, yes. We could consider delta encoding too. There are not expected to be duplicates though because a row has at most one vector when constructing a graph during compaction. And you're correct that we consume the list entirely.

Member Author

My main reason for not going further in the compression is that we flush immediately after hitting this code block, and we only need to find 4 bytes of savings to prevent a subsequent failure on put.
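For what it's worth, a sketch of the delta-encoded variant discussed above (illustrative only, assuming the same BytesDataOutput/BytesDataInput helpers): since the row ids are sorted, writing gaps instead of absolute values keeps most vints to one or two bytes.

// write side: encode the gaps between consecutive (sorted) row ids
int previous = 0;
for (int posting : postings.getPostings())
{
    writer.writeVInt(posting - previous);
    previous = posting;
}

// read side: accumulate the gaps back into absolute row ids
int current = 0;
for (int i = 0; i < size; i++)
{
    current += reader.readVInt();
    postingsList.add(current);
}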

}
catch (Exception e)


(here and below) catching Exception is usually a code smell, and then rethrowing as an unchecked RuntimeException is also bad.
Do we have a better way to rethrow?

Should we also handle InterruptedException?

Member Author

I updated it to handle only IOException. We do not expect to hit that condition though. We have this exception because I used the Lucene DataOutput class to handle the vint serde.
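A minimal sketch of that narrowed handling (assumed shape of the follow-up, not verified against it): only the checked IOException declared by Lucene's DataOutput is caught, and it is rethrown as java.io.UncheckedIOException since the Bytes-backed writer only writes to memory.

try
{
    BytesDataOutput writer = new BytesDataOutput(out);
    for (int posting : postings.getPostings())
        writer.writeVInt(posting);
}
catch (IOException e)
{
    // DataOutput#writeVInt declares IOException, but BytesDataOutput never performs
    // real I/O, so this path is not expected to be reachable.
    throw new UncheckedIOException(e);
}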

{
throw new RuntimeException(e);
}
}
else
{
for (Integer posting : postings.getPostings())
out.writeInt(posting);
}
}

@@ -213,20 +239,88 @@ public CompactionVectorPostings read(Bytes in, CompactionVectorPostings using) {
int ordinal = in.readInt();
int size = in.readInt();
assert size >= 0 : size;
CompactionVectorPostings cvp;
if (size == 1) {
cvp = new CompactionVectorPostings(ordinal, in.readInt());
boolean isCompressed = in.readBoolean();
if (isCompressed)
{
try
{
BytesDataInput reader = new BytesDataInput(in);
var postingsList = new IntArrayList(size, -1);
for (int i = 0; i < size; i++)
postingsList.add(reader.readVInt());

return new CompactionVectorPostings(ordinal, postingsList);
}
catch (Exception e)
{
throw new RuntimeException(e);
}
}
else
{
var postingsList = new IntArrayList(size, -1);
for (int i = 0; i < size; i++)
if (size == 1)
{
postingsList.add(in.readInt());
return new CompactionVectorPostings(ordinal, in.readInt());
}
cvp = new CompactionVectorPostings(ordinal, postingsList);
else
{
var postingsList = new IntArrayList(size, -1);
for (int i = 0; i < size; i++)
postingsList.add(in.readInt());

return new CompactionVectorPostings(ordinal, postingsList);
}
}
}

private static class BytesDataOutput extends DataOutput
{
private final Bytes<?> bytes;

public BytesDataOutput(Bytes<?> bytes)
{
this.bytes = bytes;
}

@Override
public void writeByte(byte b) throws IOException


This cannot throw IOException; maybe we can remove the "throws" clause and simplify the code above?

Member Author

You are right that we can (and should) remove this, but it doesn't fix the above code because the class has an exception in the readVInt() method signature.

{
bytes.writeByte(b);
}

@Override
public void writeBytes(byte[] b, int off, int len) throws IOException
{
bytes.write(b, off, len);
}
}

private static class BytesDataInput extends DataInput
{
private final Bytes<?> bytes;

public BytesDataInput(Bytes<?> bytes)
{
this.bytes = bytes;
}

@Override
public byte readByte() throws IOException
{
return bytes.readByte();
}

@Override
public void readBytes(byte[] b, int off, int len)
{
bytes.read(b, off, len);
}

@Override
public void skipBytes(long l)
{
bytes.readSkip(l);
}
return cvp;
}
}
}
@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.cassandra.index.sai.disk.vector;

import java.util.List;

import org.junit.Test;

import io.github.jbellis.jvector.util.RamUsageEstimator;
import net.openhft.chronicle.map.ChronicleMapBuilder;
import org.apache.cassandra.io.util.File;
import org.apache.cassandra.io.util.FileUtils;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

public class CompactionGraphTest
{

@Test
public void test10Entries() throws Exception
{
// 10 entries means we hit the limit sooner
testEntries(10, 1000, 1);
}

@Test
public void test1MEntries() throws Exception
{
// more entries means it takes longer, but we hit this bug in prod before, so it is worth testing
testEntries(1000000, 5000, 100);
}

@Test
public void test50MEntries() throws Exception
{
// more entries means it takes longer, but we hit this bug in prod before, so it is worth testing
testEntries(50000000, 5000, 100);
}

// Callers of this method are expected to provide enough iterations and postings added per iteration
// to hit the entry size limit without exceeding it too much. Note that we add postings one at a time in the
// compaction graph, so we only ever increment by 4 bytes each time we attempt to re-serialize the entry.
private void testEntries(int entries, int iterations, int postingsAddedPerIteration) throws Exception
{
File postingsFile = FileUtils.createTempFile("testfile", "tmp");
try(var postingsMap = ChronicleMapBuilder.of(Integer.class, (Class<VectorPostings.CompactionVectorPostings>) (Class) VectorPostings.CompactionVectorPostings.class)
.averageValueSize(VectorPostings.emptyBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF + 2 * Integer.BYTES)
.valueMarshaller(new VectorPostings.Marshaller())
.entries(entries)
.createPersistedTo(postingsFile.toJavaIOFile()))
{
postingsMap.put(0, new VectorPostings.CompactionVectorPostings(0, List.of()));
int rowId = 0;
boolean recoveredFromFailure = false;
for (int i = 0; i < iterations; i++)
{
// Iterate so we can fail sooner
var existing = postingsMap.get(0);
for (int j = 0; j < postingsAddedPerIteration; j++)
existing.add(rowId++);

recoveredFromFailure = CompactionGraph.safePut(postingsMap, 0, existing);
if (recoveredFromFailure)
break;
}

assertTrue("Failed to hit entry size limit", recoveredFromFailure);
// Validate that we can read (deserialize) the entry
var existing = postingsMap.get(0);
assertNotNull(existing);
assertEquals(rowId, existing.size());
// Validate that the row ids are correct
var rowIds = existing.getRowIds();
for (int i = 0; i < rowId; i++)
assertEquals(i, (int) rowIds.get(i));
}
}
}