@@ -178,6 +178,8 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IO
178178 var started = System .nanoTime ();
179179 var fieldInfo = field .fieldInfo ;
180180
181+ CagraIndexParams cagraIndexParams = createCagraIndexParams (fieldInfo .getVectorSimilarityFunction ());
182+
181183 var numVectors = field .flatFieldVectorsWriter .getVectors ().size ();
182184 if (numVectors < MIN_NUM_VECTORS_FOR_GPU_BUILD ) {
183185 if (logger .isDebugEnabled ()) {
@@ -193,12 +195,7 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IO
193195 try (
194196 var resourcesHolder = new ResourcesHolder (
195197 cuVSResourceManager ,
196- cuVSResourceManager .acquire (
197- numVectors ,
198- fieldInfo .getVectorDimension (),
199- CuVSMatrix .DataType .FLOAT ,
200- CagraIndexParams .CagraGraphBuildAlgo .NN_DESCENT
201- )
198+ cuVSResourceManager .acquire (numVectors , fieldInfo .getVectorDimension (), CuVSMatrix .DataType .FLOAT , cagraIndexParams )
202199 )
203200 ) {
204201 var builder = CuVSMatrix .deviceBuilder (
@@ -211,7 +208,7 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IO
211208 builder .addVector (vector );
212209 }
213210 try (var dataset = builder .build ()) {
214- flushFieldWithGpuGraph (resourcesHolder , fieldInfo , dataset , sortMap );
211+ flushFieldWithGpuGraph (resourcesHolder , fieldInfo , dataset , sortMap , cagraIndexParams );
215212 }
216213 }
217214 }
@@ -229,13 +226,18 @@ private void flushFieldWithMockGraph(FieldInfo fieldInfo, int numVectors, Sorter
229226 }
230227 }
231228
/**
 * Flush-time entry point for writing a field's GPU-built graph.
 *
 * Both branches below delegate to generateGpuGraphAndWriteMeta with identical
 * arguments: the sorted case (sortMap != null) does not yet apply the doc map
 * (see the TODO), so sortMap currently only selects between two equal calls.
 *
 * In this change the method gains a cagraIndexParams parameter, which is
 * forwarded unchanged to generateGpuGraphAndWriteMeta.
 */
232- private void flushFieldWithGpuGraph (ResourcesHolder resourcesHolder , FieldInfo fieldInfo , CuVSMatrix dataset , Sorter .DocMap sortMap )
233- throws IOException {
229+ private void flushFieldWithGpuGraph (
230+ ResourcesHolder resourcesHolder ,
231+ FieldInfo fieldInfo ,
232+ CuVSMatrix dataset ,
233+ Sorter .DocMap sortMap ,
234+ CagraIndexParams cagraIndexParams
235+ ) throws IOException {
234236 if (sortMap == null ) {
235- generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset );
237+ generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset , cagraIndexParams );
236238 } else {
237239 // TODO: use sortMap
238- generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset );
240+ generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset , cagraIndexParams );
239241 }
240242 }
241243
@@ -267,14 +269,19 @@ public long ramBytesUsed() {
267269 return total ;
268270 }
269271
270- private void generateGpuGraphAndWriteMeta (ResourcesHolder resourcesHolder , FieldInfo fieldInfo , CuVSMatrix dataset ) throws IOException {
272+ private void generateGpuGraphAndWriteMeta (
273+ ResourcesHolder resourcesHolder ,
274+ FieldInfo fieldInfo ,
275+ CuVSMatrix dataset ,
276+ CagraIndexParams cagraIndexParams
277+ ) throws IOException {
271278 try {
272279 assert dataset .size () >= MIN_NUM_VECTORS_FOR_GPU_BUILD ;
273280
274281 long vectorIndexOffset = vectorIndex .getFilePointer ();
275282 int [][] graphLevelNodeOffsets = new int [1 ][];
276283 final HnswGraph graph ;
277- try (var index = buildGPUIndex (resourcesHolder .resources (), fieldInfo . getVectorSimilarityFunction () , dataset )) {
284+ try (var index = buildGPUIndex (resourcesHolder .resources (), cagraIndexParams , dataset )) {
278285 assert index != null : "GPU index should be built for field: " + fieldInfo .name ;
279286 var deviceGraph = index .getGraph ();
280287 var graphSize = deviceGraph .size () * deviceGraph .columns () * Integer .BYTES ;
@@ -314,9 +321,20 @@ private void generateMockGraphAndWriteMeta(FieldInfo fieldInfo, int datasetSize)
314321
/**
 * Builds a CagraIndex over the given dataset using the supplied, already
 * constructed index parameters.
 *
 * Steps, in order: create a CagraIndex builder bound to cuVSResources, attach
 * the dataset and cagraIndexParams, call build(), notify cuVSResourceManager
 * through finishedComputation(cuVSResources), and, when debug logging is
 * enabled, log the elapsed time in milliseconds and the dataset size.
 *
 * In this change the similarityFunction parameter is replaced by
 * cagraIndexParams; the parameter construction moves to a separate method.
 *
 * NOTE(review): finishedComputation is not reached if build() throws --
 * confirm whether the resource manager expects the call on failure too.
 * NOTE(review): the debug message reads "Carga"; presumably "Cagra" was meant.
 */
315322 private CagraIndex buildGPUIndex (
316323 CuVSResourceManager .ManagedCuVSResources cuVSResources ,
317- VectorSimilarityFunction similarityFunction ,
324+ CagraIndexParams cagraIndexParams ,
318325 CuVSMatrix dataset
319326 ) throws Throwable {
// Timer covers builder setup, build() and the finishedComputation call.
327+ long startTime = System .nanoTime ();
328+ var indexBuilder = CagraIndex .newBuilder (cuVSResources ).withDataset (dataset ).withIndexParams (cagraIndexParams );
329+ var index = indexBuilder .build ();
330+ cuVSResourceManager .finishedComputation (cuVSResources );
331+ if (logger .isDebugEnabled ()) {
// Nanoseconds divided by 1_000_000.0 yields fractional milliseconds.
332+ logger .debug ("Carga index created in: {} ms; #num vectors: {}" , (System .nanoTime () - startTime ) / 1_000_000.0 , dataset .size ());
333+ }
334+ return index ;
335+ }
336+
337+ private CagraIndexParams createCagraIndexParams (VectorSimilarityFunction similarityFunction ) {
320338 CagraIndexParams .CuvsDistanceType distanceType = switch (similarityFunction ) {
321339 case COSINE -> CagraIndexParams .CuvsDistanceType .CosineExpanded ;
322340 case EUCLIDEAN -> CagraIndexParams .CuvsDistanceType .L2Expanded ;
@@ -333,22 +351,13 @@ private CagraIndex buildGPUIndex(
333351 };
334352
335353 // TODO: expose cagra index params for algorithm, NNDescentNumIterations
336- CagraIndexParams params = new CagraIndexParams .Builder ().withNumWriterThreads (1 ) // TODO: how many CPU threads we can use?
354+ return new CagraIndexParams .Builder ().withNumWriterThreads (1 ) // TODO: how many CPU threads we can use?
337355 .withCagraGraphBuildAlgo (CagraIndexParams .CagraGraphBuildAlgo .NN_DESCENT )
338356 .withGraphDegree (M )
339357 .withIntermediateGraphDegree (beamWidth )
340358 .withNNDescentNumIterations (5 )
341359 .withMetric (distanceType )
342360 .build ();
343-
344- long startTime = System .nanoTime ();
345- var indexBuilder = CagraIndex .newBuilder (cuVSResources ).withDataset (dataset ).withIndexParams (params );
346- var index = indexBuilder .build ();
347- cuVSResourceManager .finishedComputation (cuVSResources );
348- if (logger .isDebugEnabled ()) {
349- logger .debug ("Carga index created in: {} ms; #num vectors: {}" , (System .nanoTime () - startTime ) / 1_000_000.0 , dataset .size ());
350- }
351- return index ;
352361 }
353362
354363 private HnswGraph writeGraph (CuVSMatrix cagraGraph , int [][] levelNodeOffsets ) throws IOException {
@@ -505,6 +514,9 @@ private void mergeByteVectorField(
505514 var vectorValues = randomScorerSupplier == null
506515 ? null
507516 : VectorsFormatReflectionUtils .getByteScoringSupplierVectorOrNull (randomScorerSupplier );
517+
518+ CagraIndexParams cagraIndexParams = createCagraIndexParams (fieldInfo .getVectorSimilarityFunction ());
519+
508520 if (vectorValues != null ) {
509521 IndexInput slice = vectorValues .getSlice ();
510522 var input = FilterIndexInput .unwrapOnlyTest (slice );
@@ -538,15 +550,10 @@ private void mergeByteVectorField(
538550 var dataset = DatasetUtilsImpl .fromMemorySegment (packedSegment , numVectors , packedRowSize , dataType );
539551 var resourcesHolder = new ResourcesHolder (
540552 cuVSResourceManager ,
541- cuVSResourceManager .acquire (
542- numVectors ,
543- fieldInfo .getVectorDimension (),
544- dataType ,
545- CagraIndexParams .CagraGraphBuildAlgo .NN_DESCENT
546- )
553+ cuVSResourceManager .acquire (numVectors , fieldInfo .getVectorDimension (), dataType , cagraIndexParams )
547554 )
548555 ) {
549- generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset );
556+ generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset , cagraIndexParams );
550557 }
551558 }
552559 } else {
@@ -567,15 +574,10 @@ private void mergeByteVectorField(
567574 var dataset = builder .build ();
568575 var resourcesHolder = new ResourcesHolder (
569576 cuVSResourceManager ,
570- cuVSResourceManager .acquire (
571- numVectors ,
572- fieldInfo .getVectorDimension (),
573- dataType ,
574- CagraIndexParams .CagraGraphBuildAlgo .NN_DESCENT
575- )
577+ cuVSResourceManager .acquire (numVectors , fieldInfo .getVectorDimension (), dataType , cagraIndexParams )
576578 )
577579 ) {
578- generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset );
580+ generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset , cagraIndexParams );
579581 }
580582 }
581583 } else {
@@ -593,15 +595,10 @@ private void mergeByteVectorField(
593595 var dataset = builder .build ();
594596 var resourcesHolder = new ResourcesHolder (
595597 cuVSResourceManager ,
596- cuVSResourceManager .acquire (
597- numVectors ,
598- fieldInfo .getVectorDimension (),
599- dataType ,
600- CagraIndexParams .CagraGraphBuildAlgo .NN_DESCENT
601- )
598+ cuVSResourceManager .acquire (numVectors , fieldInfo .getVectorDimension (), dataType , cagraIndexParams )
602599 )
603600 ) {
604- generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset );
601+ generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset , cagraIndexParams );
605602 }
606603 }
607604 }
@@ -615,6 +612,8 @@ private void mergeFloatVectorField(
615612 var vectorValues = randomScorerSupplier == null
616613 ? null
617614 : VectorsFormatReflectionUtils .getFloatScoringSupplierVectorOrNull (randomScorerSupplier );
615+ CagraIndexParams cagraIndexParams = createCagraIndexParams (fieldInfo .getVectorSimilarityFunction ());
616+
618617 if (vectorValues != null ) {
619618 IndexInput slice = vectorValues .getSlice ();
620619 var input = FilterIndexInput .unwrapOnlyTest (slice );
@@ -625,15 +624,10 @@ private void mergeFloatVectorField(
625624 .fromInput (memorySegmentAccessInput , numVectors , fieldInfo .getVectorDimension (), dataType );
626625 var resourcesHolder = new ResourcesHolder (
627626 cuVSResourceManager ,
628- cuVSResourceManager .acquire (
629- numVectors ,
630- fieldInfo .getVectorDimension (),
631- dataType ,
632- CagraIndexParams .CagraGraphBuildAlgo .NN_DESCENT
633- )
627+ cuVSResourceManager .acquire (numVectors , fieldInfo .getVectorDimension (), dataType , cagraIndexParams )
634628 )
635629 ) {
636- generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset );
630+ generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset , cagraIndexParams );
637631 }
638632 } else {
639633 logger .info (
@@ -653,15 +647,10 @@ private void mergeFloatVectorField(
653647 var dataset = builder .build ();
654648 var resourcesHolder = new ResourcesHolder (
655649 cuVSResourceManager ,
656- cuVSResourceManager .acquire (
657- numVectors ,
658- fieldInfo .getVectorDimension (),
659- dataType ,
660- CagraIndexParams .CagraGraphBuildAlgo .NN_DESCENT
661- )
650+ cuVSResourceManager .acquire (numVectors , fieldInfo .getVectorDimension (), dataType , cagraIndexParams )
662651 )
663652 ) {
664- generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset );
653+ generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset , cagraIndexParams );
665654 }
666655 }
667656 } else {
@@ -680,15 +669,10 @@ private void mergeFloatVectorField(
680669 var dataset = builder .build ();
681670 var resourcesHolder = new ResourcesHolder (
682671 cuVSResourceManager ,
683- cuVSResourceManager .acquire (
684- numVectors ,
685- fieldInfo .getVectorDimension (),
686- dataType ,
687- CagraIndexParams .CagraGraphBuildAlgo .NN_DESCENT
688- )
672+ cuVSResourceManager .acquire (numVectors , fieldInfo .getVectorDimension (), dataType , cagraIndexParams )
689673 )
690674 ) {
691- generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset );
675+ generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset , cagraIndexParams );
692676 }
693677 }
694678 }
0 commit comments