Skip to content

Commit a53e8e7

Browse files
authored
LUCENE-9615: Expose HnswGraphBuilder index-time hyperparameters as FieldType attributes (from Shubham Beniwal))
1 parent 8f75933 commit a53e8e7

File tree

6 files changed

+133
-30
lines changed

6 files changed

+133
-30
lines changed

lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorWriter.java

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ public void writeField(FieldInfo fieldInfo, VectorValues vectors) throws IOExcep
123123
(RandomAccessVectorValuesProducer) vectors,
124124
vectorIndexOffset,
125125
offsets,
126-
count);
126+
count,
127+
fieldInfo.getAttribute(HnswGraphBuilder.HNSW_MAX_CONN_ATTRIBUTE_KEY),
128+
fieldInfo.getAttribute(HnswGraphBuilder.HNSW_BEAM_WIDTH_ATTRIBUTE_KEY));
127129
} else {
128130
throw new IllegalArgumentException(
129131
"Indexing an HNSW graph requires a random access vector values, got " + vectors);
@@ -188,9 +190,35 @@ private void writeGraph(
188190
RandomAccessVectorValuesProducer vectorValues,
189191
long graphDataOffset,
190192
long[] offsets,
191-
int count)
193+
int count,
194+
String maxConnStr,
195+
String beamWidthStr)
192196
throws IOException {
193-
HnswGraphBuilder hnswGraphBuilder = new HnswGraphBuilder(vectorValues);
197+
int maxConn, beamWidth;
198+
if (maxConnStr == null) {
199+
maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;
200+
} else {
201+
try {
202+
maxConn = Integer.parseInt(maxConnStr);
203+
} catch (NumberFormatException e) {
204+
throw new NumberFormatException(
205+
"Received non integer value for max-connections parameter of HnswGraphBuilder, value: "
206+
+ maxConnStr);
207+
}
208+
}
209+
if (beamWidthStr == null) {
210+
beamWidth = HnswGraphBuilder.DEFAULT_BEAM_WIDTH;
211+
} else {
212+
try {
213+
beamWidth = Integer.parseInt(beamWidthStr);
214+
} catch (NumberFormatException e) {
215+
throw new NumberFormatException(
216+
"Received non integer value for beam-width parameter of HnswGraphBuilder, value: "
217+
+ beamWidthStr);
218+
}
219+
}
220+
HnswGraphBuilder hnswGraphBuilder =
221+
new HnswGraphBuilder(vectorValues, maxConn, beamWidth, System.currentTimeMillis());
194222
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
195223
HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
196224

lucene/core/src/java/org/apache/lucene/document/VectorField.java

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.lucene.document;
1919

2020
import org.apache.lucene.index.VectorValues;
21+
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
2122

2223
/**
2324
* A field that contains a single floating-point numeric vector (or none) for each document. Vectors
@@ -32,7 +33,7 @@
3233
*/
3334
public class VectorField extends Field {
3435

35-
private static FieldType getType(float[] v, VectorValues.SearchStrategy searchStrategy) {
36+
private static FieldType createType(float[] v, VectorValues.SearchStrategy searchStrategy) {
3637
if (v == null) {
3738
throw new IllegalArgumentException("vector value must not be null");
3839
}
@@ -53,6 +54,37 @@ private static FieldType getType(float[] v, VectorValues.SearchStrategy searchSt
5354
return type;
5455
}
5556

57+
/**
58+
* Public method to create HNSW field type with the given max-connections and beam-width
59+
* parameters that would be used by HnswGraphBuilder while constructing HNSW graph.
60+
*
61+
* @param dimension dimension of vectors
62+
* @param searchStrategy a function defining vector proximity.
63+
* @param maxConn max-connections at each HNSW graph node
64+
* @param beamWidth size of list to be used while constructing HNSW graph
65+
* @throws IllegalArgumentException if any parameter is null, or has dimension > 1024.
66+
*/
67+
public static FieldType createHnswType(
68+
int dimension, VectorValues.SearchStrategy searchStrategy, int maxConn, int beamWidth) {
69+
if (dimension == 0) {
70+
throw new IllegalArgumentException("cannot index an empty vector");
71+
}
72+
if (dimension > VectorValues.MAX_DIMENSIONS) {
73+
throw new IllegalArgumentException(
74+
"cannot index vectors with dimension greater than " + VectorValues.MAX_DIMENSIONS);
75+
}
76+
if (searchStrategy == null || !searchStrategy.isHnsw()) {
77+
throw new IllegalArgumentException(
78+
"search strategy must not be null or non HNSW type, received: " + searchStrategy);
79+
}
80+
FieldType type = new FieldType();
81+
type.setVectorDimensionsAndSearchStrategy(dimension, searchStrategy);
82+
type.putAttribute(HnswGraphBuilder.HNSW_MAX_CONN_ATTRIBUTE_KEY, String.valueOf(maxConn));
83+
type.putAttribute(HnswGraphBuilder.HNSW_BEAM_WIDTH_ATTRIBUTE_KEY, String.valueOf(beamWidth));
84+
type.freeze();
85+
return type;
86+
}
87+
5688
/**
5789
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
5890
* no value. Vectors of a single field share the same dimension and search strategy. Note that
@@ -66,7 +98,7 @@ private static FieldType getType(float[] v, VectorValues.SearchStrategy searchSt
6698
* dimension > 1024.
6799
*/
68100
public VectorField(String name, float[] vector, VectorValues.SearchStrategy searchStrategy) {
69-
super(name, getType(vector, searchStrategy));
101+
super(name, createType(vector, searchStrategy));
70102
fieldsData = vector;
71103
}
72104

@@ -84,6 +116,21 @@ public VectorField(String name, float[] vector) {
84116
this(name, vector, VectorValues.SearchStrategy.EUCLIDEAN_HNSW);
85117
}
86118

119+
/**
120+
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
121+
* no value. Vectors of a single field share the same dimension and search strategy.
122+
*
123+
* @param name field name
124+
* @param vector value
125+
* @param fieldType field type
126+
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
127+
* dimension > 1024.
128+
*/
129+
public VectorField(String name, float[] vector, FieldType fieldType) {
130+
super(name, fieldType);
131+
fieldsData = vector;
132+
}
133+
87134
/** Return the vector value of this field */
88135
public float[] vectorValue() {
89136
return (float[]) fieldsData;

lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,12 @@ public final class HnswGraphBuilder {
4343
*/
4444

4545
// default max connections per node
46-
public static int DEFAULT_MAX_CONN = 16;
46+
public static final int DEFAULT_MAX_CONN = 16;
47+
public static String HNSW_MAX_CONN_ATTRIBUTE_KEY = "max_connections";
4748

4849
// default candidate list size
49-
public static int DEFAULT_BEAM_WIDTH = 16;
50+
public static final int DEFAULT_BEAM_WIDTH = 16;
51+
public static String HNSW_BEAM_WIDTH_ATTRIBUTE_KEY = "beam_width";
5052

5153
private final int maxConn;
5254
private final int beamWidth;

lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.apache.lucene.codecs.lucene90.Lucene90VectorReader;
3131
import org.apache.lucene.document.Document;
3232
import org.apache.lucene.document.Field;
33+
import org.apache.lucene.document.FieldType;
3334
import org.apache.lucene.document.SortedDocValuesField;
3435
import org.apache.lucene.document.StringField;
3536
import org.apache.lucene.document.VectorField;
@@ -58,16 +59,15 @@ public class TestKnnGraph extends LuceneTestCase {
5859
public void setup() {
5960
randSeed = random().nextLong();
6061
if (random().nextBoolean()) {
61-
maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;
62-
HnswGraphBuilder.DEFAULT_MAX_CONN = random().nextInt(256) + 2;
62+
maxConn = random().nextInt(256) + 2;
6363
}
6464
int strategy = random().nextInt(SearchStrategy.values().length - 1) + 1;
6565
searchStrategy = SearchStrategy.values()[strategy];
6666
}
6767

6868
@After
6969
public void cleanup() {
70-
HnswGraphBuilder.DEFAULT_MAX_CONN = maxConn;
70+
maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;
7171
}
7272

7373
/** Basic test of creating documents in a graph */
@@ -196,7 +196,7 @@ private float[][] randomVectors(int numDoc, int dimension) {
196196
int[][] copyGraph(KnnGraphValues values) throws IOException {
197197
int size = values.size();
198198
int[][] graph = new int[size][];
199-
int[] scratch = new int[HnswGraphBuilder.DEFAULT_MAX_CONN];
199+
int[] scratch = new int[maxConn];
200200
for (int node = 0; node < size; node++) {
201201
int n, count = 0;
202202
values.seek(node);
@@ -368,12 +368,12 @@ private void assertConsistentGraph(IndexWriter iw, float[][] values) throws IOEx
368368
assertTrue(
369369
"Graph has " + graphSize + " nodes, but one of them has no neighbors", graphSize > 1);
370370
}
371-
if (HnswGraphBuilder.DEFAULT_MAX_CONN > graphSize) {
371+
if (maxConn > graphSize) {
372372
// assert that the graph in each leaf is connected
373373
assertConnected(graph);
374374
} else {
375375
// assert that max-connections was respected
376-
assertMaxConn(graph, HnswGraphBuilder.DEFAULT_MAX_CONN);
376+
assertMaxConn(graph, maxConn);
377377
}
378378
totalGraphDocs += graphSize;
379379
}
@@ -439,7 +439,10 @@ private void add(IndexWriter iw, int id, float[] vector, SearchStrategy searchSt
439439
throws IOException {
440440
Document doc = new Document();
441441
if (vector != null) {
442-
doc.add(new VectorField(KNN_GRAPH_FIELD, vector, searchStrategy));
442+
FieldType fieldType =
443+
VectorField.createHnswType(
444+
vector.length, searchStrategy, maxConn, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
445+
doc.add(new VectorField(KNN_GRAPH_FIELD, vector, fieldType));
443446
}
444447
String idString = Integer.toString(id);
445448
doc.add(new StringField("id", idString, Field.Store.YES));

lucene/core/src/test/org/apache/lucene/index/TestVectorValues.java

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.lucene.document.Document;
2626
import org.apache.lucene.document.Field;
2727
import org.apache.lucene.document.Field.Store;
28+
import org.apache.lucene.document.FieldType;
2829
import org.apache.lucene.document.NumericDocValuesField;
2930
import org.apache.lucene.document.StringField;
3031
import org.apache.lucene.document.VectorField;
@@ -76,11 +77,16 @@ public void testFieldConstructor() {
7677
public void testFieldConstructorExceptions() {
7778
expectThrows(IllegalArgumentException.class, () -> new VectorField(null, new float[1]));
7879
expectThrows(IllegalArgumentException.class, () -> new VectorField("f", null));
79-
expectThrows(IllegalArgumentException.class, () -> new VectorField("f", new float[1], null));
80+
expectThrows(
81+
IllegalArgumentException.class,
82+
() -> new VectorField("f", new float[1], (SearchStrategy) null));
8083
expectThrows(IllegalArgumentException.class, () -> new VectorField("f", new float[0]));
8184
expectThrows(
8285
IllegalArgumentException.class,
8386
() -> new VectorField("f", new float[VectorValues.MAX_DIMENSIONS + 1]));
87+
expectThrows(
88+
IllegalArgumentException.class,
89+
() -> new VectorField("f", new float[VectorValues.MAX_DIMENSIONS + 1], (FieldType) null));
8490
}
8591

8692
public void testFieldSetValue() {
@@ -92,6 +98,25 @@ public void testFieldSetValue() {
9298
expectThrows(IllegalArgumentException.class, () -> field.setVectorValue(null));
9399
}
94100

101+
public void testFieldCreateFieldType() {
102+
expectThrows(
103+
IllegalArgumentException.class,
104+
() -> VectorField.createHnswType(0, SearchStrategy.EUCLIDEAN_HNSW, 16, 16));
105+
expectThrows(
106+
IllegalArgumentException.class,
107+
() ->
108+
VectorField.createHnswType(
109+
VectorValues.MAX_DIMENSIONS + 1, SearchStrategy.EUCLIDEAN_HNSW, 16, 16));
110+
expectThrows(
111+
IllegalArgumentException.class,
112+
() -> VectorField.createHnswType(VectorValues.MAX_DIMENSIONS + 1, null, 16, 16));
113+
expectThrows(
114+
IllegalArgumentException.class,
115+
() ->
116+
VectorField.createHnswType(
117+
VectorValues.MAX_DIMENSIONS + 1, SearchStrategy.NONE, 16, 16));
118+
}
119+
95120
// Illegal schema change tests:
96121

97122
public void testIllegalDimChangeTwoDocs() throws Exception {

lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import java.util.Set;
3838
import org.apache.lucene.codecs.lucene90.Lucene90VectorReader;
3939
import org.apache.lucene.document.Document;
40+
import org.apache.lucene.document.FieldType;
4041
import org.apache.lucene.document.StoredField;
4142
import org.apache.lucene.document.VectorField;
4243
import org.apache.lucene.index.CodecReader;
@@ -83,6 +84,8 @@ public class KnnGraphTester {
8384
private boolean reindex;
8485
private boolean forceMerge;
8586
private int reindexTimeMsec;
87+
private int beamWidth;
88+
private int maxConn;
8689

8790
@SuppressForbidden(reason = "uses Random()")
8891
private KnnGraphTester() {
@@ -132,13 +135,13 @@ private void run(String... args) throws Exception {
132135
if (iarg == args.length - 1) {
133136
throw new IllegalArgumentException("-beamWidthIndex requires a following number");
134137
}
135-
HnswGraphBuilder.DEFAULT_BEAM_WIDTH = Integer.parseInt(args[++iarg]);
138+
beamWidth = Integer.parseInt(args[++iarg]);
136139
break;
137140
case "-maxConn":
138141
if (iarg == args.length - 1) {
139142
throw new IllegalArgumentException("-maxConn requires a following number");
140143
}
141-
HnswGraphBuilder.DEFAULT_MAX_CONN = Integer.parseInt(args[++iarg]);
144+
maxConn = Integer.parseInt(args[++iarg]);
142145
break;
143146
case "-dim":
144147
if (iarg == args.length - 1) {
@@ -223,12 +226,7 @@ private void run(String... args) throws Exception {
223226
}
224227

225228
private String formatIndexPath(Path docsPath) {
226-
return docsPath.getFileName()
227-
+ "-"
228-
+ HnswGraphBuilder.DEFAULT_MAX_CONN
229-
+ "-"
230-
+ HnswGraphBuilder.DEFAULT_BEAM_WIDTH
231-
+ ".index";
229+
return docsPath.getFileName() + "-" + maxConn + "-" + beamWidth + ".index";
232230
}
233231

234232
@SuppressForbidden(reason = "Prints stuff")
@@ -250,9 +248,7 @@ private void printFanoutHist(Path indexPath) throws IOException {
250248
private void dumpGraph(Path docsPath) throws IOException {
251249
try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
252250
RandomAccessVectorValues values = vectors.randomAccess();
253-
HnswGraphBuilder builder =
254-
new HnswGraphBuilder(
255-
vectors, HnswGraphBuilder.DEFAULT_MAX_CONN, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, 0);
251+
HnswGraphBuilder builder = new HnswGraphBuilder(vectors, maxConn, beamWidth, 0);
256252
// start at node 1
257253
for (int i = 1; i < numDocs; i++) {
258254
builder.addGraphNode(values.vectorValue(i));
@@ -413,8 +409,8 @@ private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][]
413409
totalCpuTime / (float) numIters,
414410
numDocs,
415411
fanout,
416-
HnswGraphBuilder.DEFAULT_MAX_CONN,
417-
HnswGraphBuilder.DEFAULT_BEAM_WIDTH,
412+
maxConn,
413+
beamWidth,
418414
totalVisited,
419415
reindexTimeMsec);
420416
}
@@ -574,6 +570,9 @@ private int createIndex(Path docsPath, Path indexPath) throws IOException {
574570
iwc.setRAMBufferSizeMB(1994d);
575571
// iwc.setMaxBufferedDocs(10000);
576572

573+
FieldType fieldType =
574+
VectorField.createHnswType(
575+
dim, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW, maxConn, beamWidth);
577576
if (quiet == false) {
578577
iwc.setInfoStream(new PrintStreamInfoStream(System.out));
579578
System.out.println("creating index in " + indexPath);
@@ -598,8 +597,7 @@ private int createIndex(Path docsPath, Path indexPath) throws IOException {
598597
vectors.get(vector);
599598
Document doc = new Document();
600599
// System.out.println("vector=" + vector[0] + "," + vector[1] + "...");
601-
doc.add(
602-
new VectorField(KNN_FIELD, vector, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW));
600+
doc.add(new VectorField(KNN_FIELD, vector, fieldType));
603601
doc.add(new StoredField(ID_FIELD, i));
604602
iw.addDocument(doc);
605603
}

0 commit comments

Comments
 (0)