@@ -71,6 +71,7 @@ final class ES92GpuHnswVectorsWriter extends KnnVectorsWriter {
7171 private static final Logger logger = LogManager .getLogger (ES92GpuHnswVectorsWriter .class );
7272 private static final long SHALLOW_RAM_BYTES_USED = RamUsageEstimator .shallowSizeOfInstance (ES92GpuHnswVectorsWriter .class );
7373 private static final int LUCENE99_HNSW_DIRECT_MONOTONIC_BLOCK_SHIFT = 16 ;
74+ private static final long DIRECT_COPY_THRESHOLD_IN_BYTES = 128 * 1024 * 1024 ; // 128 MiB; graphs smaller than this are converted with toHost() before being written
7475
7576 private final CuVSResourceManager cuVSResourceManager ;
7677 private final SegmentWriteState segmentWriteState ;
@@ -191,10 +192,14 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IO
191192 // Will not be indexed on the GPU
192193 flushFieldWithMockGraph (fieldInfo , numVectors , sortMap );
193194 } else {
194- var cuVSResources = cuVSResourceManager .acquire (numVectors , fieldInfo .getVectorDimension (), CuVSMatrix .DataType .FLOAT );
195- try {
195+ try (
196+ var resourcesHolder = new ResourcesHolder (
197+ cuVSResourceManager ,
198+ cuVSResourceManager .acquire (numVectors , fieldInfo .getVectorDimension (), CuVSMatrix .DataType .FLOAT )
199+ )
200+ ) {
196201 var builder = CuVSMatrix .deviceBuilder (
197- cuVSResources ,
202+ resourcesHolder . resources () ,
198203 numVectors ,
199204 fieldInfo .getVectorDimension (),
200205 CuVSMatrix .DataType .FLOAT
@@ -203,10 +208,8 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IO
203208 builder .addVector (vector );
204209 }
205210 try (var dataset = builder .build ()) {
206- flushFieldWithGpuGraph (cuVSResources , fieldInfo , dataset , sortMap );
211+ flushFieldWithGpuGraph (resourcesHolder , fieldInfo , dataset , sortMap );
207212 }
208- } finally {
209- cuVSResourceManager .release (cuVSResources );
210213 }
211214 }
212215 var elapsed = started - System .nanoTime ();
@@ -223,17 +226,13 @@ private void flushFieldWithMockGraph(FieldInfo fieldInfo, int numVectors, Sorter
223226 }
224227 }
225228
226- private void flushFieldWithGpuGraph (
227- CuVSResourceManager .ManagedCuVSResources resources ,
228- FieldInfo fieldInfo ,
229- CuVSMatrix dataset ,
230- Sorter .DocMap sortMap
231- ) throws IOException {
229+ private void flushFieldWithGpuGraph (ResourcesHolder resourcesHolder , FieldInfo fieldInfo , CuVSMatrix dataset , Sorter .DocMap sortMap )
230+ throws IOException {
232231 if (sortMap == null ) {
233- generateGpuGraphAndWriteMeta (resources , fieldInfo , dataset );
232+ generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset );
234233 } else {
235234 // TODO: use sortMap; for now this branch makes the same call as the unsorted case and sortMap is ignored
236- generateGpuGraphAndWriteMeta (resources , fieldInfo , dataset );
235+ generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset );
237236 }
238237 }
239238
@@ -265,20 +264,25 @@ public long ramBytesUsed() {
265264 return total ;
266265 }
267266
268- private void generateGpuGraphAndWriteMeta (
269- CuVSResourceManager .ManagedCuVSResources cuVSResources ,
270- FieldInfo fieldInfo ,
271- CuVSMatrix dataset
272- ) throws IOException {
267+ private void generateGpuGraphAndWriteMeta (ResourcesHolder resourcesHolder , FieldInfo fieldInfo , CuVSMatrix dataset ) throws IOException {
273268 try {
274269 assert dataset .size () >= MIN_NUM_VECTORS_FOR_GPU_BUILD ;
275270
276271 long vectorIndexOffset = vectorIndex .getFilePointer ();
277272 int [][] graphLevelNodeOffsets = new int [1 ][];
278273 final HnswGraph graph ;
279- try (var index = buildGPUIndex (cuVSResources , fieldInfo .getVectorSimilarityFunction (), dataset )) {
274+ try (var index = buildGPUIndex (resourcesHolder . resources () , fieldInfo .getVectorSimilarityFunction (), dataset )) {
280275 assert index != null : "GPU index should be built for field: " + fieldInfo .name ;
281- graph = writeGraph (index .getGraph (), graphLevelNodeOffsets );
276+ var deviceGraph = index .getGraph ();
277+ var graphSize = deviceGraph .size () * deviceGraph .columns () * Integer .BYTES ;
278+ if (graphSize < DIRECT_COPY_THRESHOLD_IN_BYTES ) {
279+ try (var hostGraph = deviceGraph .toHost ()) {
280+ resourcesHolder .close ();
281+ graph = writeGraph (hostGraph , graphLevelNodeOffsets );
282+ }
283+ } else {
284+ graph = writeGraph (deviceGraph , graphLevelNodeOffsets );
285+ }
282286 }
283287 long vectorIndexLength = vectorIndex .getFilePointer () - vectorIndexOffset ;
284288 writeMeta (fieldInfo , vectorIndexOffset , vectorIndexLength , (int ) dataset .size (), graph , graphLevelNodeOffsets );
@@ -479,25 +483,35 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
479483 if (numVectors >= MIN_NUM_VECTORS_FOR_GPU_BUILD ) {
480484 if (input instanceof MemorySegmentAccessInput memorySegmentAccessInput ) {
481485 // Direct access to mmapped file
482- final var dataset = DatasetUtils .getInstance ()
483- .fromInput (memorySegmentAccessInput , numVectors , fieldInfo .getVectorDimension (), dataType );
484-
485- var cuVSResources = cuVSResourceManager .acquire (numVectors , fieldInfo .getVectorDimension (), dataType );
486- try {
487- generateGpuGraphAndWriteMeta (cuVSResources , fieldInfo , dataset );
488- } finally {
489- dataset .close ();
490- cuVSResourceManager .release (cuVSResources );
486+
487+ try (
488+ var dataset = DatasetUtils .getInstance ()
489+ .fromInput (memorySegmentAccessInput , numVectors , fieldInfo .getVectorDimension (), dataType );
490+ var resourcesHolder = new ResourcesHolder (
491+ cuVSResourceManager ,
492+ cuVSResourceManager .acquire (numVectors , fieldInfo .getVectorDimension (), dataType )
493+ )
494+ ) {
495+ generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset );
491496 }
492497 } else {
493498 logger .debug (
494499 () -> "Cannot mmap merged raw vectors temporary file. IndexInput type [" + input .getClass ().getSimpleName () + "]"
495500 );
496501
497- var cuVSResources = cuVSResourceManager .acquire (numVectors , fieldInfo .getVectorDimension (), dataType );
498- try {
502+ try (
503+ var resourcesHolder = new ResourcesHolder (
504+ cuVSResourceManager ,
505+ cuVSResourceManager .acquire (numVectors , fieldInfo .getVectorDimension (), dataType )
506+ )
507+ ) {
499508 // Read vector-by-vector
500- var builder = CuVSMatrix .deviceBuilder (cuVSResources , numVectors , fieldInfo .getVectorDimension (), dataType );
509+ var builder = CuVSMatrix .deviceBuilder (
510+ resourcesHolder .resources (),
511+ numVectors ,
512+ fieldInfo .getVectorDimension (),
513+ dataType
514+ );
501515
502516 // During merging, we use quantized data, so we need to support byte[] too.
503517 // That's how our current formats work: use floats during indexing, and quantized data to build a graph
@@ -517,10 +531,8 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
517531 }
518532 }
519533 try (var dataset = builder .build ()) {
520- generateGpuGraphAndWriteMeta (cuVSResources , fieldInfo , dataset );
534+ generateGpuGraphAndWriteMeta (resourcesHolder , fieldInfo , dataset );
521535 }
522- } finally {
523- cuVSResourceManager .release (cuVSResources );
524536 }
525537 }
526538 } else {
0 commit comments