Skip to content

Commit 019a241

Browse files
authored
Parallelize graph writes (#542)
* adding async file channel usage * update regression tests to use parallel writes * batching based on dataset size
1 parent 02fea87 commit 019a241

File tree

11 files changed

+1808
-25
lines changed

11 files changed

+1808
-25
lines changed

benchmarks-jmh/scripts/test_node_setup.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,5 @@ java --enable-native-access=ALL-UNNAMED \
4343
--add-modules=jdk.incubator.vector \
4444
-XX:+HeapDumpOnOutOfMemoryError \
4545
-Xmx14G -Djvector.experimental.enable_native_vectorization=true \
46-
-jar target/benchmarks-jmh-4.0.0-beta.3-SNAPSHOT.jar
46+
-jar target/benchmarks-jmh-4.0.0-rc.4-SNAPSHOT.jar
4747

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package io.github.jbellis.jvector.bench;
17+
18+
import io.github.jbellis.jvector.disk.ReaderSupplierFactory;
19+
import io.github.jbellis.jvector.graph.GraphIndexBuilder;
20+
import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
21+
import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
22+
import io.github.jbellis.jvector.graph.NodesIterator;
23+
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
24+
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
25+
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter;
26+
import io.github.jbellis.jvector.graph.disk.OrdinalMapper;
27+
import io.github.jbellis.jvector.graph.disk.feature.Feature;
28+
import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
29+
import io.github.jbellis.jvector.graph.disk.feature.FusedPQ;
30+
import io.github.jbellis.jvector.graph.disk.feature.NVQ;
31+
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
32+
import io.github.jbellis.jvector.quantization.NVQuantization;
33+
import io.github.jbellis.jvector.quantization.PQVectors;
34+
import io.github.jbellis.jvector.quantization.ProductQuantization;
35+
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
36+
import io.github.jbellis.jvector.vector.VectorizationProvider;
37+
import io.github.jbellis.jvector.vector.types.VectorFloat;
38+
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
39+
import org.openjdk.jmh.annotations.*;
40+
import org.openjdk.jmh.infra.Blackhole;
41+
42+
import java.io.IOException;
43+
import java.nio.file.Files;
44+
import java.nio.file.Path;
45+
import java.util.*;
46+
import java.util.concurrent.TimeUnit;
47+
import java.util.concurrent.atomic.AtomicInteger;
48+
import java.util.function.IntFunction;
49+
50+
import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED;
51+
52+
/**
53+
* JMH benchmark that mirrors the ParallelWriteExample: it builds a graph from vectors, then
54+
* writes the graph to disk sequentially and in parallel using NVQ + FUSED_PQ features,
55+
* and verifies that the outputs are identical.
56+
*/
57+
@BenchmarkMode(Mode.AverageTime)
58+
@OutputTimeUnit(TimeUnit.MILLISECONDS)
59+
@State(Scope.Benchmark)
60+
@Fork(value = 1, jvmArgsAppend = {"--add-modules=jdk.incubator.vector", "--enable-preview", "-Djvector.experimental.enable_native_vectorization=false"})
61+
@Warmup(iterations = 1)
62+
@Measurement(iterations = 2)
63+
@Threads(1)
64+
public class ParallelWriteBenchmark {
65+
private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport();
66+
67+
@Param({"100000"})
68+
int numBaseVectors;
69+
70+
@Param({"1024"})
71+
int dimension;
72+
73+
@Param({"true", "false"})
74+
boolean addHierarchy;
75+
76+
// Graph build parameters
77+
final int M = 32;
78+
final int efConstruction = 100;
79+
final float neighborOverflow = 1.2f;
80+
final float alpha = 1.2f;
81+
//final boolean addHierarchy = false;
82+
final boolean refineFinalGraph = true;
83+
84+
// Dataset and index state
85+
private RandomAccessVectorValues floatVectors;
86+
private PQVectors pqVectors;
87+
private ImmutableGraphIndex graph;
88+
89+
// Feature state reused between iterations
90+
private NVQ nvqFeature;
91+
private FusedPQ fusedPQFeature;
92+
private OrdinalMapper identityMapper;
93+
private Map<FeatureId, IntFunction<Feature.State>> inlineSuppliers;
94+
95+
// Paths
96+
private Path tempDir;
97+
private final AtomicInteger fileCounter = new AtomicInteger();
98+
99+
@Setup(Level.Trial)
100+
public void setup() throws IOException {
101+
// Generate random vectors
102+
final var baseVectors = new ArrayList<VectorFloat<?>>(numBaseVectors);
103+
for (int i = 0; i < numBaseVectors; i++) {
104+
baseVectors.add(createRandomVector(dimension));
105+
}
106+
floatVectors = new ListRandomAccessVectorValues(baseVectors, dimension);
107+
108+
// Compute PQ compression
109+
final int pqM = Math.max(1, dimension / 8);
110+
final boolean centerData = true; // for EUCLIDEAN
111+
final var pq = ProductQuantization.compute(floatVectors, pqM, 256, centerData, UNWEIGHTED);
112+
pqVectors = (PQVectors) pq.encodeAll(floatVectors);
113+
114+
// Build graph using PQ build score provider
115+
final var bsp = BuildScoreProvider.pqBuildScoreProvider(VectorSimilarityFunction.EUCLIDEAN, pqVectors);
116+
try (var builder = new GraphIndexBuilder(bsp, floatVectors.dimension(), M, efConstruction,
117+
neighborOverflow, alpha, addHierarchy, refineFinalGraph)) {
118+
graph = builder.build(floatVectors);
119+
}
120+
121+
// Prepare features
122+
int nSubVectors = floatVectors.dimension() == 2 ? 1 : 2;
123+
var nvq = NVQuantization.compute(floatVectors, nSubVectors);
124+
nvqFeature = new NVQ(nvq);
125+
fusedPQFeature = new FusedPQ(graph.maxDegree(), pqVectors.getCompressor());
126+
127+
inlineSuppliers = new EnumMap<>(FeatureId.class);
128+
inlineSuppliers.put(FeatureId.NVQ_VECTORS, ordinal -> new NVQ.State(nvq.encode(floatVectors.getVector(ordinal))));
129+
130+
identityMapper = new OrdinalMapper.IdentityMapper(floatVectors.size() - 1);
131+
132+
// Temp directory for outputs
133+
tempDir = Files.createTempDirectory("parallel-write-bench");
134+
}
135+
136+
@TearDown(Level.Trial)
137+
public void tearDown() throws IOException {
138+
if (tempDir != null) {
139+
// Best-effort cleanup of all files created
140+
try (var stream = Files.list(tempDir)) {
141+
stream.forEach(p -> {
142+
try { Files.deleteIfExists(p); } catch (IOException ignored) {}
143+
});
144+
}
145+
Files.deleteIfExists(tempDir);
146+
}
147+
}
148+
149+
@Benchmark
150+
public void writeSequentialThenParallelAndVerify(Blackhole blackhole) throws IOException {
151+
// Unique output files per invocation
152+
int idx = fileCounter.getAndIncrement();
153+
Path sequentialPath = tempDir.resolve("graph-sequential-" + idx);
154+
Path parallelPath = tempDir.resolve("graph-parallel-" + idx);
155+
156+
long startSeq = System.nanoTime();
157+
writeGraph(graph, sequentialPath, false);
158+
long seqTime = System.nanoTime() - startSeq;
159+
160+
long startPar = System.nanoTime();
161+
writeGraph(graph, parallelPath, true);
162+
long parTime = System.nanoTime() - startPar;
163+
164+
// Report times and speedup for this invocation
165+
double seqMs = seqTime / 1_000_000.0;
166+
double parMs = parTime / 1_000_000.0;
167+
double speedup = parTime == 0 ? Double.NaN : seqTime / (double) parTime;
168+
System.out.printf("Sequential write: %.2f ms, Parallel write: %.2f ms, Speedup: %.2fx%n", seqMs, parMs, speedup);
169+
170+
// Load and verify identical
171+
OnDiskGraphIndex sequentialIndex = OnDiskGraphIndex.load(ReaderSupplierFactory.open(sequentialPath));
172+
OnDiskGraphIndex parallelIndex = OnDiskGraphIndex.load(ReaderSupplierFactory.open(parallelPath));
173+
try {
174+
verifyIndicesIdentical(sequentialIndex, parallelIndex);
175+
} finally {
176+
sequentialIndex.close();
177+
parallelIndex.close();
178+
}
179+
180+
// Consume sizes to prevent DCE
181+
blackhole.consume(Files.size(sequentialPath));
182+
blackhole.consume(Files.size(parallelPath));
183+
184+
// Cleanup files after each invocation to limit disk usage
185+
Files.deleteIfExists(sequentialPath);
186+
Files.deleteIfExists(parallelPath);
187+
}
188+
189+
private void writeGraph(ImmutableGraphIndex graph,
190+
Path path,
191+
boolean parallel) throws IOException {
192+
try (var writer = new OnDiskGraphIndexWriter.Builder(graph, path)
193+
.withParallelWrites(parallel)
194+
.with(nvqFeature)
195+
.with(fusedPQFeature)
196+
.withMapper(identityMapper)
197+
.build()) {
198+
var view = graph.getView();
199+
Map<FeatureId, IntFunction<Feature.State>> writeSuppliers = new EnumMap<>(FeatureId.class);
200+
writeSuppliers.put(FeatureId.NVQ_VECTORS, inlineSuppliers.get(FeatureId.NVQ_VECTORS));
201+
writeSuppliers.put(FeatureId.FUSED_PQ, ordinal -> new FusedPQ.State(view, pqVectors, ordinal));
202+
203+
writer.write(writeSuppliers);
204+
view.close();
205+
}
206+
}
207+
208+
private static void verifyIndicesIdentical(OnDiskGraphIndex index1, OnDiskGraphIndex index2) throws IOException {
209+
// Basic properties
210+
if (index1.getMaxLevel() != index2.getMaxLevel()) {
211+
throw new AssertionError("Max levels differ: " + index1.getMaxLevel() + " vs " + index2.getMaxLevel());
212+
}
213+
if (index1.getIdUpperBound() != index2.getIdUpperBound()) {
214+
throw new AssertionError("ID upper bounds differ: " + index1.getIdUpperBound() + " vs " + index2.getIdUpperBound());
215+
}
216+
if (!index1.getFeatureSet().equals(index2.getFeatureSet())) {
217+
throw new AssertionError("Feature sets differ: " + index1.getFeatureSet() + " vs " + index2.getFeatureSet());
218+
}
219+
220+
try (var view1 = index1.getView(); var view2 = index2.getView()) {
221+
if (!view1.entryNode().equals(view2.entryNode())) {
222+
throw new AssertionError("Entry nodes differ: " + view1.entryNode() + " vs " + view2.entryNode());
223+
}
224+
for (int level = 0; level <= index1.getMaxLevel(); level++) {
225+
if (index1.size(level) != index2.size(level)) {
226+
throw new AssertionError("Layer " + level + " sizes differ: " + index1.size(level) + " vs " + index2.size(level));
227+
}
228+
if (index1.getDegree(level) != index2.getDegree(level)) {
229+
throw new AssertionError("Layer " + level + " degrees differ: " + index1.getDegree(level) + " vs " + index2.getDegree(level));
230+
}
231+
232+
// Collect node IDs in arrays
233+
java.util.List<Integer> nodeList1 = new java.util.ArrayList<>();
234+
java.util.List<Integer> nodeList2 = new java.util.ArrayList<>();
235+
NodesIterator nodes1 = index1.getNodes(level);
236+
while (nodes1.hasNext()) nodeList1.add(nodes1.nextInt());
237+
NodesIterator nodes2 = index2.getNodes(level);
238+
while (nodes2.hasNext()) nodeList2.add(nodes2.nextInt());
239+
if (!nodeList1.equals(nodeList2)) {
240+
throw new AssertionError("Layer " + level + " has different node sets");
241+
}
242+
243+
// Compare neighbors
244+
for (int nodeId : nodeList1) {
245+
NodesIterator neighbors1 = view1.getNeighborsIterator(level, nodeId);
246+
NodesIterator neighbors2 = view2.getNeighborsIterator(level, nodeId);
247+
if (neighbors1.size() != neighbors2.size()) {
248+
throw new AssertionError("Layer " + level + " node " + nodeId + " neighbor counts differ: " + neighbors1.size() + " vs " + neighbors2.size());
249+
}
250+
int[] n1 = new int[neighbors1.size()];
251+
int[] n2 = new int[neighbors2.size()];
252+
for (int i = 0; i < n1.length; i++) {
253+
n1[i] = neighbors1.nextInt();
254+
n2[i] = neighbors2.nextInt();
255+
}
256+
if (!Arrays.equals(n1, n2)) {
257+
throw new AssertionError("Layer " + level + " node " + nodeId + " has different neighbor sets");
258+
}
259+
}
260+
}
261+
262+
// Optional vector checks (layer 0)
263+
if (index1.getFeatureSet().contains(FeatureId.INLINE_VECTORS) ||
264+
index1.getFeatureSet().contains(FeatureId.NVQ_VECTORS)) {
265+
int vectorsChecked = 0;
266+
int maxToCheck = Math.min(100, index1.size(0));
267+
NodesIterator nodes = index1.getNodes(0);
268+
while (nodes.hasNext() && vectorsChecked < maxToCheck) {
269+
int node = nodes.nextInt();
270+
if (index1.getFeatureSet().contains(FeatureId.INLINE_VECTORS)) {
271+
var vec1 = view1.getVector(node);
272+
var vec2 = view2.getVector(node);
273+
if (!vec1.equals(vec2)) {
274+
throw new AssertionError("Node " + node + " vectors differ");
275+
}
276+
}
277+
vectorsChecked++;
278+
}
279+
}
280+
}
281+
}
282+
283+
private VectorFloat<?> createRandomVector(int dimension) {
284+
VectorFloat<?> vector = VECTOR_TYPE_SUPPORT.createFloatVector(dimension);
285+
for (int i = 0; i < dimension; i++) {
286+
vector.set(i, (float) Math.random());
287+
}
288+
return vector;
289+
}
290+
}

0 commit comments

Comments
 (0)