diff --git a/src/java/org/apache/cassandra/index/sai/disk/v5/V5VectorPostingsWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v5/V5VectorPostingsWriter.java
index 27e2317b057f..7b35200ceff8 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/v5/V5VectorPostingsWriter.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/v5/V5VectorPostingsWriter.java
@@ -129,9 +129,8 @@ public static RemappedPostings describeForCompaction(Structure structure, int gr
         int maxOldOrdinal = Integer.MIN_VALUE;
         int maxRow = Integer.MIN_VALUE;
         var extraOrdinals = new Int2IntHashMap(Integer.MIN_VALUE);
-        for (var entry : postingsMap.entrySet())
+        for (var postings : postingsMap.values())
         {
-            var postings = entry.getValue();
             int ordinal = postings.getOrdinal();
             maxOldOrdinal = Math.max(maxOldOrdinal, ordinal);
@@ -470,8 +469,20 @@ private static RemappedPostings createGenericRenumberedMapping(Set live
      */
     public static RemappedPostings createGenericIdentityMapping(Map<VectorFloat<?>, ? extends VectorPostings<?>> postingsMap)
     {
-        var maxOldOrdinal = postingsMap.values().stream().mapToInt(VectorPostings::getOrdinal).max().orElseThrow();
-        int maxRow = postingsMap.values().stream().flatMap(p -> p.getRowIds().stream()).mapToInt(i -> i).max().orElseThrow();
+        // It can be expensive to iterate over the postings map. We do it once to get the max ordinal and then
+        // again to build the bitset.
+        int maxOldOrdinal = Integer.MIN_VALUE;
+        int maxRow = Integer.MIN_VALUE;
+        for (var postings : postingsMap.values())
+        {
+            maxOldOrdinal = max(maxOldOrdinal, postings.getOrdinal());
+            for (int rowId : postings.getRowIds())
+                maxRow = max(maxRow, rowId);
+        }
+
+        if (maxOldOrdinal < 0 || maxRow < 0)
+            throw new IllegalStateException("maxOldOrdinal or maxRow is negative: " + maxOldOrdinal + ' ' + maxRow);
+
         var presentOrdinals = new FixedBitSet(maxOldOrdinal + 1);
         for (var entry : postingsMap.entrySet())
             presentOrdinals.set(entry.getValue().getOrdinal());
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java
index 8e09abfbe0bc..8cdd20d093a4 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java
@@ -127,6 +127,10 @@ public class CompactionGraph implements Closeable, Accountable
     // if `useSyntheticOrdinals` is true then we use `nextOrdinal` to avoid holes, otherwise use rowId as source of ordinals
     private final boolean useSyntheticOrdinals;
     private int nextOrdinal = 0;
+    private int postingsCount = 0;
+
+    // Used to force flush on next add
+    private boolean requiresFlush = false;
 
     // protects the fine-tuning changes (done in maybeAddVector) from addGraphNode threads
     // (and creates happens-before events so we don't need to mark the other fields volatile)
@@ -245,7 +249,8 @@ public int size()
 
     public boolean isEmpty()
     {
-        return postingsMap.values().stream().allMatch(VectorPostings::isEmpty);
+        // This relies on the fact that compaction never has vectors pointing to empty postings lists.
+        return postingsMap.isEmpty();
     }
 
     /**
@@ -270,6 +275,9 @@ public InsertionResult maybeAddVector(ByteBuffer term, int segmentRowId) throws
             return new InsertionResult(0);
         }
 
+        // At this point, we'll add the posting, so it's safe to count it.
+        postingsCount++;
+
         // if we don't see sequential rowids, it means the skipped row(s) have null vectors
         if (segmentRowId != lastRowId + 1)
             postingsStructure = Structure.ZERO_OR_ONE_TO_MANY;
@@ -347,7 +355,7 @@ public InsertionResult maybeAddVector(ByteBuffer term, int segmentRowId) throws
             var newPosting = postings.add(segmentRowId);
             assert newPosting;
             bytesUsed += postings.bytesPerPosting();
-            postingsMap.put(vector, postings); // re-serialize to disk
+            requiresFlush = safePut(postingsMap, vector, postings); // re-serialize to disk
             return new InsertionResult(bytesUsed);
         }
 
@@ -381,8 +389,7 @@ public SegmentMetadata.ComponentMetadataMap flush() throws IOException
                      postingsMap.keySet().size(), builder.getGraph().size());
         if (logger.isDebugEnabled())
         {
-            logger.debug("Writing graph with {} rows and {} distinct vectors",
-                         postingsMap.values().stream().mapToInt(VectorPostings::size).sum(), builder.getGraph().size());
+            logger.debug("Writing graph with {} rows and {} distinct vectors", postingsCount, builder.getGraph().size());
             logger.debug("Estimated size is {} + {}", compressedVectors.ramBytesUsed(), builder.getGraph().ramBytesUsed());
         }
 
@@ -476,7 +483,27 @@ public long ramBytesUsed()
 
     public boolean requiresFlush()
     {
-        return builder.getGraph().size() >= postingsEntriesAllocated;
+        return builder.getGraph().size() >= postingsEntriesAllocated || requiresFlush;
+    }
+
+    static <T> boolean safePut(ChronicleMap<T, CompactionVectorPostings> map, T key, CompactionVectorPostings value)
+    {
+        try
+        {
+            map.put(key, value);
+            return false;
+        }
+        catch (IllegalArgumentException e)
+        {
+            logger.debug("Error serializing postings to disk, will reattempt with compression. Vector {} had {} postings",
+                         key, value.size(), e);
+            // This is an extreme edge case where there are many duplicate vectors. This naive approach
+            // means that we might have a smaller vector graph than desired, but at least we will not
+            // fail to build the index.
+            value.setShouldCompress(true);
+            map.put(key, value);
+            return true;
+        }
     }
 
     private static class VectorFloatMarshaller implements BytesReader<VectorFloat<?>>, BytesWriter<VectorFloat<?>> {
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorPostings.java b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorPostings.java
index f0fd04517a2d..c5786d4ef55a 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorPostings.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorPostings.java
@@ -18,9 +18,9 @@
 
 package org.apache.cassandra.index.sai.disk.vector;
 
+import java.io.IOException;
 import java.util.List;
 import java.util.concurrent.CopyOnWriteArrayList;
-import java.util.function.Function;
 import java.util.function.ToIntFunction;
 
 import com.google.common.base.Preconditions;
@@ -30,6 +30,8 @@
 import net.openhft.chronicle.hash.serialization.BytesReader;
 import net.openhft.chronicle.hash.serialization.BytesWriter;
 import org.agrona.collections.IntArrayList;
+import org.apache.lucene.store.DataInput;
+import org.apache.lucene.store.DataOutput;
 
 public class VectorPostings<T>
 {
@@ -160,7 +162,10 @@ public int getOrdinal(boolean assertSet)
         return ordinal;
     }
 
-    public static class CompactionVectorPostings extends VectorPostings<Integer> {
+    public static class CompactionVectorPostings extends VectorPostings<Integer>
+    {
+        private volatile boolean shouldCompress = false;
+
         public CompactionVectorPostings(int ordinal, List<Integer> raw)
         {
             super(raw);
@@ -179,6 +184,11 @@ public void setOrdinal(int ordinal)
             throw new UnsupportedOperationException();
         }
 
+        public void setShouldCompress(boolean shouldCompress)
+        {
+            this.shouldCompress = shouldCompress;
+        }
+
         @Override
         public IntArrayList getRowIds()
         {
@@ -193,7 +203,7 @@ public IntArrayList getRowIds()
 
         public long bytesPerPosting()
         {
             long REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF;
-            return REF_BYTES + Integer.BYTES;
+            return REF_BYTES + Integer.BYTES + 1; // 1 byte for boolean
         }
     }
 
@@ -203,8 +213,25 @@ static class Marshaller implements BytesReader, BytesW
         public void write(Bytes out, CompactionVectorPostings postings) {
             out.writeInt(postings.ordinal);
             out.writeInt(postings.size());
-            for (Integer posting : postings.getPostings()) {
-                out.writeInt(posting);
+            out.writeBoolean(postings.shouldCompress);
+            if (postings.shouldCompress)
+            {
+                try
+                {
+                    BytesDataOutput writer = new BytesDataOutput(out);
+                    for (int posting : postings.getPostings())
+                        writer.writeVInt(posting);
+                }
+                catch (IOException e)
+                {
+                    // Not reachable because the Bytes out object does not throw an exception on write.
+                    throw new RuntimeException(e);
+                }
+            }
+            else
+            {
+                for (Integer posting : postings.getPostings())
+                    out.writeInt(posting);
             }
         }
 
@@ -213,20 +240,89 @@ public CompactionVectorPostings read(Bytes in, CompactionVectorPostings using) {
             int ordinal = in.readInt();
             int size = in.readInt();
             assert size >= 0 : size;
-            CompactionVectorPostings cvp;
-            if (size == 1) {
-                cvp = new CompactionVectorPostings(ordinal, in.readInt());
+            boolean isCompressed = in.readBoolean();
+            if (isCompressed)
+            {
+                try
+                {
+                    BytesDataInput reader = new BytesDataInput(in);
+                    var postingsList = new IntArrayList(size, -1);
+                    for (int i = 0; i < size; i++)
+                        postingsList.add(reader.readVInt());
+
+                    return new CompactionVectorPostings(ordinal, postingsList);
+                }
+                catch (IOException e)
+                {
+                    // Not reachable because the Bytes in object does not throw an exception on read.
+                    throw new RuntimeException(e);
+                }
             }
             else
             {
-                var postingsList = new IntArrayList(size, -1);
-                for (int i = 0; i < size; i++)
+                if (size == 1)
                 {
-                    postingsList.add(in.readInt());
+                    return new CompactionVectorPostings(ordinal, in.readInt());
                 }
-                cvp = new CompactionVectorPostings(ordinal, postingsList);
+                else
+                {
+                    var postingsList = new IntArrayList(size, -1);
+                    for (int i = 0; i < size; i++)
+                        postingsList.add(in.readInt());
+
+                    return new CompactionVectorPostings(ordinal, postingsList);
+                }
+            }
+        }
+
+        private static class BytesDataOutput extends DataOutput
+        {
+            private final Bytes bytes;
+
+            public BytesDataOutput(Bytes bytes)
+            {
+                this.bytes = bytes;
+            }
+
+            @Override
+            public void writeByte(byte b)
+            {
+                bytes.writeByte(b);
+            }
+
+            @Override
+            public void writeBytes(byte[] b, int off, int len)
+            {
+                bytes.write(b, off, len);
+            }
+        }
+
+        private static class BytesDataInput extends DataInput
+        {
+            private final Bytes bytes;
+
+            public BytesDataInput(Bytes bytes)
+            {
+                this.bytes = bytes;
+            }
+
+            @Override
+            public byte readByte()
+            {
+                return bytes.readByte();
+            }
+
+            @Override
+            public void readBytes(byte[] b, int off, int len)
+            {
+                bytes.read(b, off, len);
+            }
+
+            @Override
+            public void skipBytes(long l)
+            {
+                bytes.readSkip(l);
+            }
-            return cvp;
         }
     }
 }
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/vector/CompactionGraphTest.java b/test/unit/org/apache/cassandra/index/sai/disk/vector/CompactionGraphTest.java
new file mode 100644
index 000000000000..bc0eaa0351cc
--- /dev/null
+++ b/test/unit/org/apache/cassandra/index/sai/disk/vector/CompactionGraphTest.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.index.sai.disk.vector;
+
+import java.util.List;
+
+import org.junit.Test;
+
+import io.github.jbellis.jvector.util.RamUsageEstimator;
+import net.openhft.chronicle.map.ChronicleMapBuilder;
+import org.apache.cassandra.io.util.File;
+import org.apache.cassandra.io.util.FileUtils;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+public class CompactionGraphTest
+{
+
+    @Test
+    public void test10Entries() throws Exception
+    {
+        // 10 entries means we hit the limit sooner
+        testEntries(10, 1000, 1);
+    }
+
+    @Test
+    public void test1MEntries() throws Exception
+    {
+        // more entries means it takes longer, but we hit this bug in prod before, so it is worth testing
+        testEntries(1000000, 5000, 100);
+    }
+
+    @Test
+    public void test50MEntries() throws Exception
+    {
+        // more entries means it takes longer, but we hit this bug in prod before, so it is worth testing
+        testEntries(50000000, 5000, 100);
+    }
+
+    // Callers of this method are expected to provide enough iterations and postings added per iteration
+    // to hit the entry size limit without exceeding it too much. Note that we add postings one at a time in the
+    // compaction graph, so we only ever increment by 4 bytes each time we attempt to re-serialize the entry.
+    private void testEntries(int entries, int iterations, int postingsAddedPerIteration) throws Exception
+    {
+        File postingsFile = FileUtils.createTempFile("testfile", "tmp");
+        try (var postingsMap = ChronicleMapBuilder.of(Integer.class, (Class<VectorPostings.CompactionVectorPostings>) (Class<?>) VectorPostings.CompactionVectorPostings.class)
+                                                  .averageValueSize(VectorPostings.emptyBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF + 2 * Integer.BYTES)
+                                                  .valueMarshaller(new VectorPostings.Marshaller())
+                                                  .entries(entries)
+                                                  .createPersistedTo(postingsFile.toJavaIOFile()))
+        {
+            postingsMap.put(0, new VectorPostings.CompactionVectorPostings(0, List.of()));
+            int rowId = 0;
+            boolean recoveredFromFailure = false;
+            for (int i = 0; i < iterations; i++)
+            {
+                // Iterate so we can fail sooner
+                var existing = postingsMap.get(0);
+                for (int j = 0; j < postingsAddedPerIteration; j++)
+                    existing.add(rowId++);
+
+                recoveredFromFailure = CompactionGraph.safePut(postingsMap, 0, existing);
+                if (recoveredFromFailure)
+                    break;
+            }
+
+            assertTrue("Failed to hit entry size limit", recoveredFromFailure);
+            // Validate that we can read (deserialize) the entry
+            var existing = postingsMap.get(0);
+            assertNotNull(existing);
+            assertEquals(rowId, existing.size());
+            // Validate that the row ids are correct
+            var rowIds = existing.getRowIds();
+            for (int i = 0; i < rowId; i++)
+                assertEquals(i, (int) rowIds.get(i));
+        }
+    }
+}
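
Editor's note on the compression fallback (not part of the patch itself): safePut() catches the IllegalArgumentException that ChronicleMap raises when a re-serialized entry outgrows the size it was provisioned for, marks the postings with setShouldCompress(true), writes them again as Lucene-style variable-length ints (DataOutput.writeVInt) instead of fixed 4-byte ints, and returns true so that requiresFlush() forces the segment to flush. The sketch below is a minimal, self-contained illustration of that 7-bits-per-byte encoding and the space it can save on a postings list of small row ids; the class and method names are illustrative only and do not appear in the patch, which delegates to Lucene's DataOutput/DataInput through the BytesDataOutput/BytesDataInput adapters above.

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;

    // Standalone sketch of the variable-length int ("vint") format used by Lucene's
    // DataOutput.writeVInt / DataInput.readVInt: 7 data bits per byte, high bit set
    // while more bytes follow. Small, non-negative row ids need 1-3 bytes instead of 4.
    public class VIntRoundTripSketch
    {
        static void writeVInt(ByteArrayOutputStream out, int value)
        {
            while ((value & ~0x7F) != 0)
            {
                out.write((value & 0x7F) | 0x80); // low 7 bits, continuation bit set
                value >>>= 7;
            }
            out.write(value); // final byte, continuation bit clear
        }

        static int readVInt(ByteArrayInputStream in)
        {
            int value = 0;
            for (int shift = 0; ; shift += 7)
            {
                int b = in.read();
                value |= (b & 0x7F) << shift;
                if ((b & 0x80) == 0)
                    return value;
            }
        }

        public static void main(String[] args)
        {
            int postings = 1_000_000;
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            for (int rowId = 0; rowId < postings; rowId++)
                writeVInt(out, rowId);

            // Round-trip check: every row id must decode to the value that was written.
            ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
            for (int rowId = 0; rowId < postings; rowId++)
                if (readVInt(in) != rowId)
                    throw new AssertionError("round-trip mismatch at " + rowId);

            System.out.printf("vint: %d bytes vs fixed-width: %d bytes%n",
                              out.size(), (long) postings * Integer.BYTES);
        }
    }

A vint can take up to 5 bytes for values at or above 2^28, so the encoding is not guaranteed to be smaller in every case; the win here comes from postings lists holding non-negative segment row ids that are usually well below that threshold.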