 import org.apache.lucene.codecs.KnnVectorsWriter;
 import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
 import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
 import org.apache.lucene.index.ByteVectorValues;
 import org.apache.lucene.index.DocsWithFieldSet;
 import org.apache.lucene.index.FieldInfo;
@@ -42,6 +43,7 @@
 import org.elasticsearch.index.codec.vectors.ES814ScalarQuantizedVectorsFormat;
 import org.elasticsearch.logging.LogManager;
 import org.elasticsearch.logging.Logger;
+import org.elasticsearch.xpack.gpu.reflect.VectorsFormatReflectionUtils;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -73,6 +75,7 @@ final class ESGpuHnswVectorsWriter extends KnnVectorsWriter {
     private final CuVSResourceManager cuVSResourceManager;
     private final SegmentWriteState segmentWriteState;
     private final IndexOutput meta, vectorIndex;
+    private final IndexOutput vectorData;
     private final int M;
     private final int beamWidth;
     private final FlatVectorsWriter flatVectorWriter;
@@ -94,8 +97,11 @@ final class ESGpuHnswVectorsWriter extends KnnVectorsWriter {
         this.beamWidth = beamWidth;
         this.flatVectorWriter = flatVectorWriter;
         if (flatVectorWriter instanceof ES814ScalarQuantizedVectorsFormat.ES814ScalarQuantizedVectorsWriter) {
+            vectorData = VectorsFormatReflectionUtils.getQuantizedVectorDataIndexOutput(flatVectorWriter);
             dataType = CuVSMatrix.DataType.BYTE;
         } else {
+            assert flatVectorWriter instanceof Lucene99FlatVectorsWriter;
+            vectorData = VectorsFormatReflectionUtils.getVectorDataIndexOutput(flatVectorWriter);
             dataType = CuVSMatrix.DataType.FLOAT;
         }
         this.segmentWriteState = state;
@@ -148,11 +154,38 @@ public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException
     @Override
     public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
         flatVectorWriter.flush(maxDoc, sortMap);
-        for (FieldWriter field : fields) {
-            if (sortMap == null) {
-                writeField(field);
-            } else {
-                writeSortingField(field, sortMap);
+
+        try (IndexInput in = segmentWriteState.segmentInfo.dir.openInput(vectorData.getName(), IOContext.DEFAULT)) {
+            var input = FilterIndexInput.unwrapOnlyTest(in);
+
+            for (FieldWriter fieldWriter : fields) {
+                // TODO: is this inefficient? Can we get "size" in another way?
+                var numVectors = fieldWriter.flatFieldVectorsWriter.getVectors().size();
+
+                final DatasetOrVectors datasetOrVectors;
+                if (input instanceof MemorySegmentAccessInput memorySegmentAccessInput && numVectors >= MIN_NUM_VECTORS_FOR_GPU_BUILD) {
+                    // TODO: we are iterating over multiple fields, we probably need to memorySegmentAccessInput.segmentSliceOrNull()?
+                    var ds = DatasetUtils.getInstance()
+                        .fromInput(memorySegmentAccessInput, numVectors, fieldWriter.fieldInfo.getVectorDimension(), dataType);
+                    datasetOrVectors = DatasetOrVectors.fromDataset(ds);
+                } else {
+                    var builder = CuVSMatrix.hostBuilder(numVectors, fieldWriter.fieldInfo.getVectorDimension(), dataType);
+                    for (var vector : fieldWriter.flatFieldVectorsWriter.getVectors()) {
+                        builder.addVector(vector);
+                    }
+
+                    datasetOrVectors = DatasetOrVectors.fromDataset(builder.build());
+                }
+
+                try {
+                    if (sortMap == null) {
+                        writeField(fieldWriter.fieldInfo, datasetOrVectors);
+                    } else {
+                        writeSortingField(fieldWriter.fieldInfo, datasetOrVectors, sortMap);
+                    }
+                } finally {
+                    datasetOrVectors.close();
+                }
             }
         }
     }
@@ -221,38 +254,13 @@ public void close() {
         }
     }
 
-    private void writeField(FieldWriter fieldWriter) throws IOException {
-        var vectors = fieldWriter.flatFieldVectorsWriter.getVectors();
-        final DatasetOrVectors datasetOrVectors;
-        if (vectors.size() < MIN_NUM_VECTORS_FOR_GPU_BUILD) {
-            // Use vectors/CPU
-            datasetOrVectors = DatasetOrVectors.fromArray(vectors);
-        } else {
-            // Avoid another heap copy (the float[][])
-
-            // TODO: another alternative is to use CuVSMatrix.deviceBuilder(), but this requires more effort
-            // 1. support no-copy CuVSDeviceMatrix as input in CagraIndex
-            // 2. ensure we are already holding a CuVSResource here
-            var builder = CuVSMatrix.hostBuilder(vectors.size(), vectors.getFirst().length, dataType);
-            for (var vector : vectors) {
-                builder.addVector(vector);
-            }
-            datasetOrVectors = DatasetOrVectors.fromDataset(builder.build());
-        }
-        try {
-            writeFieldInternal(fieldWriter.fieldInfo, datasetOrVectors);
-        } finally {
-            datasetOrVectors.close();
-        }
-    }
-
-    private void writeSortingField(FieldWriter fieldData, Sorter.DocMap sortMap) throws IOException {
+    private void writeSortingField(FieldInfo fieldInfo, DatasetOrVectors datasetOrVectors, Sorter.DocMap sortMap) throws IOException {
         // The flatFieldVectorsWriter's flush method, called before this, has already sorted the vectors according to the sortMap.
         // We can now treat them as a simple, sorted list of vectors.
-        writeField(fieldData);
+        writeField(fieldInfo, datasetOrVectors);
     }
 
-    private void writeFieldInternal(FieldInfo fieldInfo, DatasetOrVectors datasetOrVectors) throws IOException {
+    private void writeField(FieldInfo fieldInfo, DatasetOrVectors datasetOrVectors) throws IOException {
         try {
             long vectorIndexOffset = vectorIndex.getFilePointer();
             int[][] graphLevelNodeOffsets = new int[1][];
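
Note on the VectorsFormatReflectionUtils helpers used above (getVectorDataIndexOutput / getQuantizedVectorDataIndexOutput): their implementation is not part of this diff. The sketch below shows one plausible, reflection-based way such a helper could hand back the flat writer's data output; the private field name "vectorData" on Lucene99FlatVectorsWriter and the use of plain java.lang.reflect are assumptions made for illustration, not the PR's actual code.

// Illustrative sketch only: assumes Lucene99FlatVectorsWriter keeps its raw vector
// data file in a private IndexOutput field named "vectorData", and that the package
// is opened to the caller's module so setAccessible(true) succeeds.
import java.lang.reflect.Field;

import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
import org.apache.lucene.store.IndexOutput;

final class VectorsFormatReflectionSketch {

    private VectorsFormatReflectionSketch() {}

    // Hands back the IndexOutput the flat writer streams raw vectors into, so the GPU
    // writer can re-open that same file by name and read the vectors without another
    // on-heap copy.
    static IndexOutput getVectorDataIndexOutput(FlatVectorsWriter flatVectorsWriter) {
        assert flatVectorsWriter instanceof Lucene99FlatVectorsWriter;
        try {
            Field field = Lucene99FlatVectorsWriter.class.getDeclaredField("vectorData");
            field.setAccessible(true);
            return (IndexOutput) field.get(flatVectorsWriter);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("cannot access Lucene99FlatVectorsWriter#vectorData", e);
        }
    }
}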