Skip to content

Commit 30062fc

Browse files
committed
Adding new bbq_ivf format behind a feature flag
1 parent ba50798 commit 30062fc

File tree

5 files changed

+296
-35
lines changed

5 files changed

+296
-35
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,24 @@ public class IVFVectorsFormat extends KnnVectorsFormat {
6161
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
6262
);
6363

64-
private static final int DEFAULT_VECTORS_PER_CLUSTER = 1000;
64+
public static final int DYNAMIC_NPROBE = -1;
65+
public static final int DEFAULT_VECTORS_PER_CLUSTER = 384;
66+
public static final int MIN_VECTORS_PER_CLUSTER = 64;
67+
public static final int MAX_VECTORS_PER_CLUSTER = 1 << 16; // 65536
6568

6669
private final int vectorPerCluster;
6770

6871
public IVFVectorsFormat(int vectorPerCluster) {
6972
super(NAME);
70-
if (vectorPerCluster <= 0) {
71-
throw new IllegalArgumentException("vectorPerCluster must be > 0");
73+
if (vectorPerCluster < MIN_VECTORS_PER_CLUSTER || vectorPerCluster > MAX_VECTORS_PER_CLUSTER) {
74+
throw new IllegalArgumentException(
75+
"vectorsPerCluster must be between "
76+
+ MIN_VECTORS_PER_CLUSTER
77+
+ " and "
78+
+ MAX_VECTORS_PER_CLUSTER
79+
+ ", got: "
80+
+ vectorPerCluster
81+
);
7282
}
7383
this.vectorPerCluster = vectorPerCluster;
7484
}
@@ -90,12 +100,12 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException
90100

91101
@Override
92102
public int getMaxDimensions(String fieldName) {
93-
return 1024;
103+
return 4096;
94104
}
95105

96106
@Override
97107
public String toString() {
98-
return "IVFVectorFormat";
108+
return "IVFVectorsFormat(" + "vectorPerCluster=" + vectorPerCluster + ')';
99109
}
100110

101111
static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) {

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import java.util.function.IntPredicate;
3939

4040
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
41+
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.DYNAMIC_NPROBE;
4142

4243
/**
4344
* Reader for IVF vectors. This reader is used to read the IVF vectors from the index.
@@ -226,17 +227,6 @@ public final ByteVectorValues getByteVectorValues(String field) throws IOExcepti
226227
return rawVectorsReader.getByteVectorValues(field);
227228
}
228229

229-
protected float[] getGlobalCentroid(FieldInfo info) {
230-
if (info == null || info.getVectorEncoding().equals(VectorEncoding.BYTE)) {
231-
return null;
232-
}
233-
FieldEntry entry = fields.get(info.number);
234-
if (entry == null) {
235-
return null;
236-
}
237-
return entry.globalCentroid();
238-
}
239-
240230
@Override
241231
public final void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
242232
final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
@@ -261,12 +251,9 @@ public final void search(String field, float[] target, KnnCollector knnCollector
261251
}
262252
return visitedDocs.getAndSet(docId) == false;
263253
};
264-
final int nProbe;
254+
int nProbe = DYNAMIC_NPROBE;
265255
if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
266256
nProbe = ivfSearchStrategy.getNProbe();
267-
} else {
268-
// TODO calculate nProbe given the number of centroids vs. number of vectors for given `k`
269-
nProbe = 10;
270257
}
271258

272259
FieldEntry entry = fields.get(fieldInfo.number);
@@ -277,17 +264,27 @@ public final void search(String field, float[] target, KnnCollector knnCollector
277264
target,
278265
ivfClusters
279266
);
267+
if (nProbe == DYNAMIC_NPROBE) {
268+
// empirically based, and a good dynamic to get decent recall while scaling a la "efSearch"
269+
// scaling by the number of centroids vs. the nearest neighbors requested
270+
// not perfect, but a comparative heuristic.
271+
// we might want to utilize the total vector count as well, but this is a good start
272+
nProbe = (int) Math.round(Math.log10(centroidQueryScorer.size()) * Math.sqrt(knnCollector.k()));
273+
// clip to be between 1 and the number of centroids
274+
nProbe = Math.max(Math.min(nProbe, centroidQueryScorer.size()), 1);
275+
}
280276
final NeighborQueue centroidQueue = scorePostingLists(fieldInfo, knnCollector, centroidQueryScorer, nProbe);
281277
PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring);
282278
int centroidsVisited = 0;
283279
long expectedDocs = 0;
284280
long actualDocs = 0;
285281
// initially we visit only the "centroids to search"
286-
while (centroidQueue.size() > 0 && centroidsVisited < nProbe) {
282+
while (centroidQueue.size() > 0 && centroidsVisited < nProbe && actualDocs < knnCollector.k()) {
287283
++centroidsVisited;
288284
// todo do we actually need to know the score???
289285
int centroidOrdinal = centroidQueue.pop();
290-
// todo do we need direct access to the raw centroid???
286+
// todo do we need direct access to the raw centroid???, this is used for quantizing, maybe hydrating and quantizing
287+
// is enough?
291288
expectedDocs += scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal));
292289
actualDocs += scorer.visit(knnCollector);
293290
}

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 138 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.apache.lucene.util.VectorUtil;
4040
import org.elasticsearch.common.ParsingException;
4141
import org.elasticsearch.common.settings.Setting;
42+
import org.elasticsearch.common.util.FeatureFlag;
4243
import org.elasticsearch.common.xcontent.support.XContentMapValues;
4344
import org.elasticsearch.features.NodeFeature;
4445
import org.elasticsearch.index.IndexVersion;
@@ -48,6 +49,7 @@
4849
import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat;
4950
import org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat;
5051
import org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat;
52+
import org.elasticsearch.index.codec.vectors.IVFVectorsFormat;
5153
import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat;
5254
import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat;
5355
import org.elasticsearch.index.fielddata.FieldDataContext;
@@ -62,6 +64,7 @@
6264
import org.elasticsearch.index.mapper.Mapper;
6365
import org.elasticsearch.index.mapper.MapperBuilderContext;
6466
import org.elasticsearch.index.mapper.MapperParsingException;
67+
import org.elasticsearch.index.mapper.MappingLookup;
6568
import org.elasticsearch.index.mapper.MappingParser;
6669
import org.elasticsearch.index.mapper.NumberFieldMapper;
6770
import org.elasticsearch.index.mapper.SimpleMappedFieldType;
@@ -78,6 +81,7 @@
7881
import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery;
7982
import org.elasticsearch.search.vectors.ESKnnByteVectorQuery;
8083
import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery;
84+
import org.elasticsearch.search.vectors.IVFKnnFloatVectorQuery;
8185
import org.elasticsearch.search.vectors.RescoreKnnVectorQuery;
8286
import org.elasticsearch.search.vectors.VectorData;
8387
import org.elasticsearch.search.vectors.VectorSimilarityQuery;
@@ -106,6 +110,8 @@
106110
import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_INDEX_VERSION_CREATED;
107111
import static org.elasticsearch.common.Strings.format;
108112
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
113+
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.MAX_VECTORS_PER_CLUSTER;
114+
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.MIN_VECTORS_PER_CLUSTER;
109115

110116
/**
111117
* A {@link FieldMapper} for indexing a dense vector of floats.
@@ -115,6 +121,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
115121
private static final float EPS = 1e-3f;
116122
public static final int BBQ_MIN_DIMS = 64;
117123

124+
public static final FeatureFlag IVF_FORMAT = new FeatureFlag("ivf_format");
125+
118126
public static boolean isNotUnitVector(float magnitude) {
119127
return Math.abs(magnitude - 1.0f) > EPS;
120128
}
@@ -1594,14 +1602,63 @@ public boolean supportsElementType(ElementType elementType) {
15941602
return elementType == ElementType.FLOAT;
15951603
}
15961604

1605+
@Override
1606+
public boolean supportsDimension(int dims) {
1607+
return dims >= BBQ_MIN_DIMS;
1608+
}
1609+
},
1610+
BBQ_IVF("bbq_ivf", true) {
1611+
@Override
1612+
public IndexOptions parseIndexOptions(String fieldName, Map<String, ?> indexOptionsMap, IndexVersion indexVersion) {
1613+
Object clusterSizeNode = indexOptionsMap.remove("cluster_size");
1614+
int clusterSize = IVFVectorsFormat.DEFAULT_VECTORS_PER_CLUSTER;
1615+
if (clusterSizeNode != null) {
1616+
clusterSize = XContentMapValues.nodeIntegerValue(clusterSizeNode);
1617+
if (clusterSize < MIN_VECTORS_PER_CLUSTER || clusterSize > MAX_VECTORS_PER_CLUSTER) {
1618+
throw new IllegalArgumentException(
1619+
"cluster_size must be between "
1620+
+ MIN_VECTORS_PER_CLUSTER
1621+
+ " and "
1622+
+ MAX_VECTORS_PER_CLUSTER
1623+
+ ", got: "
1624+
+ clusterSize
1625+
);
1626+
}
1627+
}
1628+
RescoreVector rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap, indexVersion);
1629+
if (rescoreVector == null) {
1630+
rescoreVector = new RescoreVector(DEFAULT_OVERSAMPLE);
1631+
}
1632+
Object nProbeNode = indexOptionsMap.remove("default_n_probe");
1633+
int nProbe = -1;
1634+
if (nProbeNode != null) {
1635+
nProbe = XContentMapValues.nodeIntegerValue(nProbeNode);
1636+
if (nProbe < 1 && nProbe != -1) {
1637+
throw new IllegalArgumentException(
1638+
"default_n_probe must be at least 1 or exactly -1, got: " + nProbe + " for field [" + fieldName + "]"
1639+
);
1640+
}
1641+
}
1642+
MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap);
1643+
return new BBQIVFIndexOptions(clusterSize, nProbe, rescoreVector);
1644+
}
1645+
1646+
@Override
1647+
public boolean supportsElementType(ElementType elementType) {
1648+
return elementType == ElementType.FLOAT;
1649+
}
1650+
15971651
@Override
15981652
public boolean supportsDimension(int dims) {
15991653
return dims >= BBQ_MIN_DIMS;
16001654
}
16011655
};
16021656

16031657
static Optional<VectorIndexType> fromString(String type) {
1604-
return Stream.of(VectorIndexType.values()).filter(vectorIndexType -> vectorIndexType.name.equals(type)).findFirst();
1658+
return Stream.of(VectorIndexType.values())
1659+
.filter(vectorIndexType -> vectorIndexType != VectorIndexType.BBQ_IVF || IVF_FORMAT.isEnabled())
1660+
.filter(vectorIndexType -> vectorIndexType.name.equals(type))
1661+
.findFirst();
16051662
}
16061663

16071664
private final String name;
@@ -2100,6 +2157,54 @@ public boolean validateDimension(int dim, boolean throwOnError) {
21002157

21012158
}
21022159

2160+
static class BBQIVFIndexOptions extends QuantizedIndexOptions {
2161+
final int clusterSize;
2162+
final int defaultNProbe;
2163+
2164+
BBQIVFIndexOptions(int clusterSize, int defaultNProbe, RescoreVector rescoreVector) {
2165+
super(VectorIndexType.BBQ_IVF, rescoreVector);
2166+
this.clusterSize = clusterSize;
2167+
this.defaultNProbe = defaultNProbe;
2168+
}
2169+
2170+
@Override
2171+
KnnVectorsFormat getVectorsFormat(ElementType elementType) {
2172+
assert elementType == ElementType.FLOAT;
2173+
return new IVFVectorsFormat(clusterSize);
2174+
}
2175+
2176+
@Override
2177+
boolean updatableTo(IndexOptions update) {
2178+
return update.type.equals(this.type);
2179+
}
2180+
2181+
@Override
2182+
boolean doEquals(IndexOptions other) {
2183+
BBQIVFIndexOptions that = (BBQIVFIndexOptions) other;
2184+
return clusterSize == that.clusterSize
2185+
&& defaultNProbe == that.defaultNProbe
2186+
&& Objects.equals(rescoreVector, that.rescoreVector);
2187+
}
2188+
2189+
@Override
2190+
int doHashCode() {
2191+
return Objects.hash(clusterSize, defaultNProbe, rescoreVector);
2192+
}
2193+
2194+
@Override
2195+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
2196+
builder.startObject();
2197+
builder.field("type", type);
2198+
builder.field("cluster_size", clusterSize);
2199+
builder.field("default_n_probe", defaultNProbe);
2200+
if (rescoreVector != null) {
2201+
rescoreVector.toXContent(builder, params);
2202+
}
2203+
builder.endObject();
2204+
return builder;
2205+
}
2206+
}
2207+
21032208
public record RescoreVector(float oversample) implements ToXContentObject {
21042209
static final String NAME = "rescore_vector";
21052210
static final String OVERSAMPLE = "oversample";
@@ -2411,17 +2516,25 @@ && isNotUnitVector(squaredMagnitude)) {
24112516
adjustedK = Math.min((int) Math.ceil(k * oversample), OVERSAMPLE_LIMIT);
24122517
numCands = Math.max(adjustedK, numCands);
24132518
}
2414-
Query knnQuery = parentFilter != null
2415-
? new ESDiversifyingChildrenFloatKnnVectorQuery(
2416-
name(),
2417-
queryVector,
2418-
filter,
2419-
adjustedK,
2420-
numCands,
2421-
parentFilter,
2422-
knnSearchStrategy
2423-
)
2424-
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy);
2519+
if (parentFilter != null && indexOptions instanceof BBQIVFIndexOptions) {
2520+
throw new IllegalArgumentException("IVF index does not support nested queries");
2521+
}
2522+
Query knnQuery;
2523+
if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) {
2524+
knnQuery = new IVFKnnFloatVectorQuery(name(), queryVector, adjustedK, filter, bbqIndexOptions.defaultNProbe);
2525+
} else {
2526+
knnQuery = parentFilter != null
2527+
? new ESDiversifyingChildrenFloatKnnVectorQuery(
2528+
name(),
2529+
queryVector,
2530+
filter,
2531+
adjustedK,
2532+
numCands,
2533+
parentFilter,
2534+
knnSearchStrategy
2535+
)
2536+
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy);
2537+
}
24252538
if (rescore) {
24262539
knnQuery = new RescoreKnnVectorQuery(
24272540
name(),
@@ -2651,6 +2764,19 @@ public FieldMapper.Builder getMergeBuilder() {
26512764
return new Builder(leafName(), indexCreatedVersion).init(this);
26522765
}
26532766

2767+
@Override
2768+
public void doValidate(MappingLookup mappers) {
2769+
if (indexOptions instanceof BBQIVFIndexOptions && mappers.nestedLookup().getNestedParent(fullPath()) != null) {
2770+
throw new IllegalArgumentException(
2771+
"["
2772+
+ CONTENT_TYPE
2773+
+ "] fields with index type ["
2774+
+ indexOptions.type
2775+
+ "] with cannot be indexed if they're within [nested] mappings"
2776+
);
2777+
}
2778+
}
2779+
26542780
private static IndexOptions parseIndexOptions(String fieldName, Object propNode, IndexVersion indexVersion) {
26552781
@SuppressWarnings("unchecked")
26562782
Map<String, ?> indexOptionsMap = (Map<String, ?>) propNode;

server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import com.carrotsearch.randomizedtesting.generators.RandomPicks;
1212

1313
import org.apache.lucene.codecs.Codec;
14+
import org.apache.lucene.codecs.FilterCodec;
1415
import org.apache.lucene.codecs.KnnVectorsFormat;
1516
import org.apache.lucene.index.VectorEncoding;
1617
import org.apache.lucene.index.VectorSimilarityFunction;
@@ -20,6 +21,13 @@
2021
import org.junit.Before;
2122

2223
import java.util.List;
24+
import java.util.Locale;
25+
26+
import static java.lang.String.format;
27+
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.MAX_VECTORS_PER_CLUSTER;
28+
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.MIN_VECTORS_PER_CLUSTER;
29+
import static org.hamcrest.Matchers.is;
30+
import static org.hamcrest.Matchers.oneOf;
2331

2432
public class IVFVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
2533

@@ -32,7 +40,7 @@ public class IVFVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
3240
@Before
3341
@Override
3442
public void setUp() throws Exception {
35-
format = new IVFVectorsFormat(random().nextInt(10, 1000));
43+
format = new IVFVectorsFormat(random().nextInt(MIN_VECTORS_PER_CLUSTER, IVFVectorsFormat.MAX_VECTORS_PER_CLUSTER));
3644
super.setUp();
3745
}
3846

@@ -62,4 +70,28 @@ public void testSearchWithVisitedLimit() {
6270
protected Codec getCodec() {
6371
return TestUtil.alwaysKnnVectorsFormat(format);
6472
}
73+
74+
@Override
75+
public void testAdvance() throws Exception {
76+
// TODO re-enable with hierarchical IVF, clustering as it is is flaky
77+
}
78+
79+
public void testToString() {
80+
FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) {
81+
@Override
82+
public KnnVectorsFormat knnVectorsFormat() {
83+
return new IVFVectorsFormat(128);
84+
}
85+
};
86+
String expectedPattern = "IVFVectorsFormat(vectorPerCluster=128)";
87+
88+
var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
89+
var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
90+
assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
91+
}
92+
93+
public void testLimits() {
94+
expectThrows(IllegalArgumentException.class, () -> new IVFVectorsFormat(MIN_VECTORS_PER_CLUSTER - 1));
95+
expectThrows(IllegalArgumentException.class, () -> new IVFVectorsFormat(MAX_VECTORS_PER_CLUSTER + 1));
96+
}
6597
}

0 commit comments

Comments
 (0)