@@ -81,10 +81,12 @@ float int4QuantizedScore(
8181 }
8282 }
8383
84+ private abstract static class ChildCentroidQueryScorer extends BaseCentroidQueryScorer implements CentroidWClusterOffsetQueryScorer {}
85+
8486 private abstract static class ParentCentroidQueryScorer extends BaseCentroidQueryScorer implements CentroidWChildrenQueryScorer {}
8587
8688 @ Override
87- CentroidQueryScorer getCentroidScorer (
89+ ChildCentroidQueryScorer getChildCentroidScorer (
8890 FieldInfo fieldInfo ,
8991 int numParentCentroids ,
9092 int numCentroids ,
@@ -106,14 +108,16 @@ CentroidQueryScorer getCentroidScorer(
106108 quantized [i ] = (byte ) scratch [i ];
107109 }
108110 final ES91Int4VectorsScorer scorer = ESVectorUtil .getES91Int4VectorsScorer (centroids , fieldInfo .getVectorDimension ());
109- return new BaseCentroidQueryScorer () {
111+ return new ChildCentroidQueryScorer () {
110112 int currentCentroid = -1 ;
111113 private final float [] centroid = new float [fieldInfo .getVectorDimension ()];
112114 private final float [] centroidCorrectiveValues = new float [3 ];
115+ private int clusterOrdinal ;
113116 private final long quantizedVectorByteSize = fieldInfo .getVectorDimension () + 3 * Float .BYTES + Short .BYTES ;
117+ private final long quantizedVectorNodeByteSize = quantizedVectorByteSize + Integer .BYTES ;
114118 private final long parentNodeByteSize = quantizedVectorByteSize + 2 * Integer .BYTES ;
115119 private final long quantizedCentroidsOffset = numParentCentroids * parentNodeByteSize ;
116- private final long rawCentroidsOffset = numParentCentroids * parentNodeByteSize + numCentroids * quantizedVectorByteSize ;
120+ private final long rawCentroidsOffset = numParentCentroids * parentNodeByteSize + numCentroids * quantizedVectorNodeByteSize ;
117121 private final long rawCentroidsByteSize = (long ) Float .BYTES * fieldInfo .getVectorDimension ();
118122
119123 @ Override
@@ -137,16 +141,30 @@ public void bulkScore(NeighborQueue queue, int start, int end) throws IOExceptio
137141 assert start >= 0 ;
138142 assert end > 0 ;
139143 assert start + end <= numCentroids ;
140- centroids .seek (quantizedCentroidsOffset + quantizedVectorByteSize * start );
144+ centroids .seek (quantizedCentroidsOffset + quantizedVectorNodeByteSize * start );
141145 for (int i = start ; i < end ; i ++) {
142146 queue .add (i , score ());
143147 }
144148 }
145149
150+ // TODO: this causes seeks refactor to move this to the end of the block in this file
151+ @ Override
152+ public int getClusterOrdinal (int centroidOrdinal ) throws IOException {
153+ if (centroidOrdinal != currentCentroid ) {
154+ centroids .seek (quantizedCentroidsOffset + quantizedVectorNodeByteSize * centroidOrdinal + quantizedVectorByteSize );
155+ clusterOrdinal = centroids .readInt ();
156+ }
157+ return clusterOrdinal ;
158+ }
159+
146160 private float score () throws IOException {
147161 final float qcDist = scorer .int4DotProduct (quantized );
148162 centroids .readFloats (centroidCorrectiveValues , 0 , 3 );
149163 final int quantizedCentroidComponentSum = Short .toUnsignedInt (centroids .readShort ());
164+
165+ // TODO: should we consider a different format such as moving these to the beginning of the file to benefit bulk read
166+ centroids .skipBytes (Integer .BYTES );
167+
150168 return int4QuantizedScore (
151169 qcDist ,
152170 queryParams ,
@@ -198,7 +216,7 @@ public int size() {
198216
199217 @ Override
200218 public float [] centroid (int centroidOrdinal ) throws IOException {
201- throw new UnsupportedOperationException ("can't score at the parent level" );
219+ throw new IllegalStateException ("can't score at the parent level" );
202220 }
203221
204222 private void readChildDetails (int centroidOrdinal ) throws IOException {
0 commit comments