3030import org .apache .lucene .store .IndexInput ;
3131import org .apache .lucene .util .Bits ;
3232import org .elasticsearch .core .IOUtils ;
33+ import org .elasticsearch .index .codec .vectors .GenericFlatVectorReaders ;
3334import org .elasticsearch .search .vectors .IVFKnnSearchStrategy ;
3435
3536import java .io .Closeable ;
3637import java .io .IOException ;
3738import java .util .ArrayList ;
3839import java .util .Collections ;
39- import java .util .HashMap ;
4040import java .util .List ;
4141import java .util .Map ;
4242
4949 */
5050public abstract class IVFVectorsReader extends KnnVectorsReader {
5151
52- private record FlatVectorsReaderKey (String formatName , boolean useDirectIO ) {
53- private FlatVectorsReaderKey (FieldEntry entry ) {
54- this (entry .rawVectorFormatName , entry .useDirectIOReads );
55- }
56-
57- @ Override
58- public String toString () {
59- return formatName + (useDirectIO ? " with Direct IO" : "" );
60- }
61- }
62-
6352 private final IndexInput ivfCentroids , ivfClusters ;
6453 private final SegmentReadState state ;
6554 private final FieldInfos fieldInfos ;
6655 protected final IntObjectHashMap <FieldEntry > fields ;
67- private final Map <FlatVectorsReaderKey , FlatVectorsReader > rawVectorReaders ;
68-
69- @ FunctionalInterface
70- public interface GetFormatReader {
71- FlatVectorsReader getReader (String formatName , boolean useDirectIO ) throws IOException ;
72- }
56+ private final GenericFlatVectorReaders genericReaders ;
7357
7458 @ SuppressWarnings ("this-escape" )
75- protected IVFVectorsReader (SegmentReadState state , GetFormatReader getFormatReader ) throws IOException {
59+ protected IVFVectorsReader (SegmentReadState state , GenericFlatVectorReaders . LoadFlatVectorsReader loadReader ) throws IOException {
7660 this .state = state ;
7761 this .fieldInfos = state .fieldInfos ;
7862 this .fields = new IntObjectHashMap <>();
63+ this .genericReaders = new GenericFlatVectorReaders ();
7964 String meta = IndexFileNames .segmentFileName (
8065 state .segmentInfo .name ,
8166 state .segmentSuffix ,
@@ -86,7 +71,6 @@ protected IVFVectorsReader(SegmentReadState state, GetFormatReader getFormatRead
8671 boolean success = false ;
8772 try (ChecksumIndexInput ivfMeta = state .directory .openChecksumInput (meta )) {
8873 Throwable priorE = null ;
89- Map <FlatVectorsReaderKey , FlatVectorsReader > readers = null ;
9074 try {
9175 versionMeta = CodecUtil .checkIndexHeader (
9276 ivfMeta ,
@@ -96,13 +80,12 @@ protected IVFVectorsReader(SegmentReadState state, GetFormatReader getFormatRead
9680 state .segmentInfo .getId (),
9781 state .segmentSuffix
9882 );
99- readers = readFields (ivfMeta , getFormatReader , versionMeta );
83+ readFields (ivfMeta , versionMeta , genericReaders , loadReader );
10084 } catch (Throwable exception ) {
10185 priorE = exception ;
10286 } finally {
10387 CodecUtil .checkFooter (ivfMeta , priorE );
10488 }
105- this .rawVectorReaders = readers ;
10689 ivfCentroids = openDataInput (
10790 state ,
10891 versionMeta ,
@@ -169,30 +152,23 @@ private static IndexInput openDataInput(
169152 }
170153 }
171154
172- private Map <FlatVectorsReaderKey , FlatVectorsReader > readFields (ChecksumIndexInput meta , GetFormatReader loadReader , int versionMeta )
173- throws IOException {
174- Map <FlatVectorsReaderKey , FlatVectorsReader > readers = new HashMap <>();
155+ private void readFields (
156+ ChecksumIndexInput meta ,
157+ int versionMeta ,
158+ GenericFlatVectorReaders genericFields ,
159+ GenericFlatVectorReaders .LoadFlatVectorsReader loadReader
160+ ) throws IOException {
175161 for (int fieldNumber = meta .readInt (); fieldNumber != -1 ; fieldNumber = meta .readInt ()) {
176162 final FieldInfo info = fieldInfos .fieldInfo (fieldNumber );
177163 if (info == null ) {
178164 throw new CorruptIndexException ("Invalid field number: " + fieldNumber , meta );
179165 }
180166
181167 FieldEntry fieldEntry = readField (meta , info , versionMeta );
182- FlatVectorsReaderKey key = new FlatVectorsReaderKey (fieldEntry );
183-
184- FlatVectorsReader reader = readers .get (key );
185- if (reader == null ) {
186- reader = loadReader .getReader (fieldEntry .rawVectorFormatName , fieldEntry .useDirectIOReads );
187- if (reader == null ) {
188- throw new IllegalStateException ("Cannot find flat vector format: " + fieldEntry .rawVectorFormatName );
189- }
190- readers .put (key , reader );
191- }
168+ genericFields .loadField (fieldNumber , fieldEntry , loadReader );
192169
193170 fields .put (info .number , fieldEntry );
194171 }
195- return readers ;
196172 }
197173
198174 private FieldEntry readField (IndexInput input , FieldInfo info , int versionMeta ) throws IOException {
@@ -256,29 +232,17 @@ private static VectorEncoding readVectorEncoding(DataInput input) throws IOExcep
256232
257233 @ Override
258234 public final void checkIntegrity () throws IOException {
259- for (var reader : rawVectorReaders . values ()) {
235+ for (var reader : genericReaders . allReaders ()) {
260236 reader .checkIntegrity ();
261237 }
262238 CodecUtil .checksumEntireFile (ivfCentroids );
263239 CodecUtil .checksumEntireFile (ivfClusters );
264240 }
265241
266- private FieldEntry getFieldEntryOrThrow (String field ) {
267- final FieldInfo info = fieldInfos .fieldInfo (field );
268- final FieldEntry entry ;
269- if (info == null || (entry = fields .get (info .number )) == null ) {
270- throw new IllegalArgumentException ("field=\" " + field + "\" not found" );
271- }
272- return entry ;
273- }
274-
275242 private FlatVectorsReader getReaderForField (String field ) {
276- var readerKey = new FlatVectorsReaderKey (getFieldEntryOrThrow (field ));
277- FlatVectorsReader reader = rawVectorReaders .get (readerKey );
278- if (reader == null ) throw new IllegalArgumentException (
279- "Could not find raw vector format [" + readerKey + "] for field [" + field + "]"
280- );
281- return reader ;
243+ FieldInfo info = fieldInfos .fieldInfo (field );
244+ if (info == null ) throw new IllegalArgumentException ("Could not find field [" + field + "]" );
245+ return genericReaders .getReaderForField (info .number );
282246 }
283247
284248 @ Override
@@ -399,7 +363,7 @@ public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) {
399363
400364 @ Override
401365 public void close () throws IOException {
402- List <Closeable > closeables = new ArrayList <>(rawVectorReaders . values ());
366+ List <Closeable > closeables = new ArrayList <>(genericReaders . allReaders ());
403367 Collections .addAll (closeables , ivfCentroids , ivfClusters );
404368 IOUtils .close (closeables );
405369 }
@@ -416,7 +380,7 @@ protected record FieldEntry(
416380 long postingListLength ,
417381 float [] globalCentroid ,
418382 float globalCentroidDp
419- ) {
383+ ) implements GenericFlatVectorReaders . Field {
420384 IndexInput centroidSlice (IndexInput centroidFile ) throws IOException {
421385 return centroidFile .slice ("centroids" , centroidOffset , centroidLength );
422386 }
0 commit comments