3232import org .apache .lucene .util .VectorUtil ;
3333import org .elasticsearch .core .IOUtils ;
3434import org .elasticsearch .core .SuppressForbidden ;
35+ import org .elasticsearch .index .codec .vectors .cluster .PrefetchingFloatVectorValues ;
3536
3637import java .io .IOException ;
3738import java .io .UncheckedIOException ;
@@ -120,8 +121,11 @@ public final KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOExc
120121 return rawVectorDelegate ;
121122 }
122123
123- abstract CentroidAssignments calculateCentroids (FieldInfo fieldInfo , FloatVectorValues floatVectorValues , float [] globalCentroid )
124- throws IOException ;
124+ abstract CentroidAssignments calculateCentroids (
125+ FieldInfo fieldInfo ,
126+ PrefetchingFloatVectorValues floatVectorValues ,
127+ float [] globalCentroid
128+ ) throws IOException ;
125129
126130 abstract void writeCentroids (
127131 FieldInfo fieldInfo ,
@@ -134,7 +138,7 @@ abstract void writeCentroids(
134138 abstract LongValues buildAndWritePostingsLists (
135139 FieldInfo fieldInfo ,
136140 CentroidSupplier centroidSupplier ,
137- FloatVectorValues floatVectorValues ,
141+ PrefetchingFloatVectorValues floatVectorValues ,
138142 IndexOutput postingsOutput ,
139143 long fileOffset ,
140144 int [] assignments ,
@@ -144,7 +148,7 @@ abstract LongValues buildAndWritePostingsLists(
144148 abstract LongValues buildAndWritePostingsLists (
145149 FieldInfo fieldInfo ,
146150 CentroidSupplier centroidSupplier ,
147- FloatVectorValues floatVectorValues ,
151+ PrefetchingFloatVectorValues floatVectorValues ,
148152 IndexOutput postingsOutput ,
149153 long fileOffset ,
150154 MergeState mergeState ,
@@ -165,7 +169,11 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
165169 for (FieldWriter fieldWriter : fieldWriters ) {
166170 final float [] globalCentroid = new float [fieldWriter .fieldInfo .getVectorDimension ()];
167171 // build a float vector values with random access
168- final FloatVectorValues floatVectorValues = getFloatVectorValues (fieldWriter .fieldInfo , fieldWriter .delegate , maxDoc );
172+ final PrefetchingFloatVectorValues floatVectorValues = getFloatVectorValues (
173+ fieldWriter .fieldInfo ,
174+ fieldWriter .delegate ,
175+ maxDoc
176+ );
169177 // build centroids
170178 final CentroidAssignments centroidAssignments = calculateCentroids (fieldWriter .fieldInfo , floatVectorValues , globalCentroid );
171179 // wrap centroids with a supplier
@@ -199,47 +207,22 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
199207 }
200208 }
201209
202- private static FloatVectorValues getFloatVectorValues (
210+ private static PrefetchingFloatVectorValues getFloatVectorValues (
203211 FieldInfo fieldInfo ,
204212 FlatFieldVectorsWriter <float []> fieldVectorsWriter ,
205213 int maxDoc
206214 ) throws IOException {
207215 List <float []> vectors = fieldVectorsWriter .getVectors ();
208216 if (vectors .size () == maxDoc ) {
209- return FloatVectorValues . fromFloats (vectors , fieldInfo .getVectorDimension ());
217+ return PrefetchingFloatVectorValues . floats (vectors , fieldInfo .getVectorDimension ());
210218 }
211219 final DocIdSetIterator iterator = fieldVectorsWriter .getDocsWithFieldSet ().iterator ();
212220 final int [] docIds = new int [vectors .size ()];
213221 for (int i = 0 ; i < docIds .length ; i ++) {
214222 docIds [i ] = iterator .nextDoc ();
215223 }
216224 assert iterator .nextDoc () == NO_MORE_DOCS ;
217- return new FloatVectorValues () {
218- @ Override
219- public float [] vectorValue (int ord ) {
220- return vectors .get (ord );
221- }
222-
223- @ Override
224- public FloatVectorValues copy () {
225- return this ;
226- }
227-
228- @ Override
229- public int dimension () {
230- return fieldInfo .getVectorDimension ();
231- }
232-
233- @ Override
234- public int size () {
235- return vectors .size ();
236- }
237-
238- @ Override
239- public int ordToDoc (int ord ) {
240- return docIds [ord ];
241- }
242- };
225+ return PrefetchingFloatVectorValues .floats (vectors , fieldInfo .getVectorDimension (), docIds );
243226 }
244227
245228 @ Override
@@ -297,7 +280,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
297280 IndexInput vectors = mergeState .segmentInfo .dir .openInput (tempRawVectorsFileName , IOContext .DEFAULT );
298281 IndexInput docs = docsFileName == null ? null : mergeState .segmentInfo .dir .openInput (docsFileName , IOContext .DEFAULT )
299282 ) {
300- final FloatVectorValues floatVectorValues = getFloatVectorValues (fieldInfo , docs , vectors , numVectors );
283+ final PrefetchingFloatVectorValues floatVectorValues = getFloatVectorValues (fieldInfo , docs , vectors , numVectors );
301284
302285 final long centroidOffset ;
303286 final long centroidLength ;
@@ -396,15 +379,26 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
396379 }
397380 }
398381
399- private static FloatVectorValues getFloatVectorValues (FieldInfo fieldInfo , IndexInput docs , IndexInput vectors , int numVectors )
400- throws IOException {
382+ private static PrefetchingFloatVectorValues getFloatVectorValues (
383+ FieldInfo fieldInfo ,
384+ IndexInput docs ,
385+ IndexInput vectors ,
386+ int numVectors
387+ ) throws IOException {
401388 if (numVectors == 0 ) {
402- return FloatVectorValues . fromFloats (List .of (), fieldInfo .getVectorDimension ());
389+ return PrefetchingFloatVectorValues . floats (List .of (), fieldInfo .getVectorDimension ());
403390 }
404391 final long vectorLength = (long ) Float .BYTES * fieldInfo .getVectorDimension ();
405392 final float [] vector = new float [fieldInfo .getVectorDimension ()];
406393 final RandomAccessInput randomDocs = docs == null ? null : docs .randomAccessSlice (0 , docs .length ());
407- return new FloatVectorValues () {
394+ return new PrefetchingFloatVectorValues () {
395+ @ Override
396+ public void prefetch (int ... ord ) throws IOException {
397+ for (int o : ord ) {
398+ vectors .prefetch (o * vectorLength , vectorLength );
399+ }
400+ }
401+
408402 @ Override
409403 public float [] vectorValue (int ord ) throws IOException {
410404 vectors .seek (ord * vectorLength );
@@ -413,7 +407,8 @@ public float[] vectorValue(int ord) throws IOException {
413407 }
414408
415409 @ Override
416- public FloatVectorValues copy () {
410+ public PrefetchingFloatVectorValues copy () {
411+ assert false ;
417412 return this ;
418413 }
419414
0 commit comments