@@ -129,44 +129,6 @@ final void checkOrdinal(int ord) {
129129 private static final ValueLayout .OfInt INT_UNALIGNED_LE =
130130 JAVA_INT_UNALIGNED .withOrder (ByteOrder .LITTLE_ENDIAN );
131131
132- // XXX I need to return something wraps the MemorySegment and can produce the
133- // corrective terms
134- // on demand. rep is probably (MemorySegment, MemorySegment) with a slice for
135- // the corrective terms.
136- @ SuppressWarnings ("restricted" )
137- MemorySegment getVector (int ord ) throws IOException {
138- checkOrdinal (ord );
139- long byteOffset = (long ) ord * nodeSize ;
140- MemorySegment vector = input .segmentSliceOrNull (byteOffset , vectorByteSize );
141- if (vector == null ) {
142- if (scratch == null ) {
143- scratch = new byte [nodeSize ];
144- }
145- input .readBytes (byteOffset , scratch , 0 , nodeSize );
146- vector = MemorySegment .ofArray (scratch ).reinterpret (vectorByteSize );
147- }
148- return vector ;
149- }
150-
151- @ SuppressWarnings ("restricted" )
152- OptimizedScalarQuantizer .QuantizationResult getCorrectiveTerms (int ord ) throws IOException {
153- checkOrdinal (ord );
154- long byteOffset = (long ) ord * nodeSize + vectorByteSize ;
155- MemorySegment node = input .segmentSliceOrNull (byteOffset , CORRECTIVE_TERMS_SIZE );
156- if (node == null ) {
157- if (scratch == null ) {
158- scratch = new byte [nodeSize ];
159- }
160- input .readBytes (byteOffset , scratch , 0 , CORRECTIVE_TERMS_SIZE );
161- node = MemorySegment .ofArray (scratch ).reinterpret (CORRECTIVE_TERMS_SIZE );
162- }
163- return new OptimizedScalarQuantizer .QuantizationResult (
164- Float .intBitsToFloat (node .get (INT_UNALIGNED_LE , 0 )),
165- Float .intBitsToFloat (node .get (INT_UNALIGNED_LE , Integer .BYTES )),
166- Float .intBitsToFloat (node .get (INT_UNALIGNED_LE , Integer .BYTES * 2 )),
167- node .get (INT_UNALIGNED_LE , Integer .BYTES * 3 ));
168- }
169-
170132 record Node (
171133 MemorySegment vector ,
172134 float lowerInterval ,
@@ -188,9 +150,8 @@ Node getNode(int ord) throws IOException {
188150 }
189151 // XXX investigate reordering the vector so that corrective terms appear first.
190152 // we're forced to read them immediately to avoid creating a second memory
191- // segment which is
192- // not cheap, so they might as well be read first to avoid additional memory
193- // latency.
153+ // segment which is not cheap, so they might as well be read first to avoid
154+ // additional memory latency.
194155 return new Node (
195156 vector .reinterpret (vectorByteSize ),
196157 Float .intBitsToFloat (vector .get (INT_UNALIGNED_LE , vectorByteSize )),
@@ -260,7 +221,7 @@ public float score(int node) throws IOException {
260221 }
261222 }
262223
263- private record RandomVectorScorerSupplierImpl (
224+ record RandomVectorScorerSupplierImpl (
264225 VectorSimilarityFunction similarityFunction ,
265226 QuantizedByteVectorValues values ,
266227 MemorySegmentAccessInput input )
@@ -293,23 +254,37 @@ private static class UpdateableRandomVectorScorerImpl extends RandomVectorScorer
293254 @ Override
294255 public void setScoringOrdinal (int ord ) throws IOException {
295256 checkOrdinal (ord );
296- query = getVector (ord );
297- queryCorrectiveTerms = getCorrectiveTerms (ord );
257+ Node node = getNode (ord );
258+ query = node .vector ();
259+ queryCorrectiveTerms =
260+ new OptimizedScalarQuantizer .QuantizationResult (
261+ node .lowerInterval (),
262+ node .upperInterval (),
263+ node .additionalCorrection (),
264+ node .componentSum ());
298265 }
299266
300267 @ Override
301268 public float score (int node ) throws IOException {
302- MemorySegment doc = getVector (node );
269+ Node doc = getNode (node );
303270 float dotProduct =
304271 switch (getScalarEncoding ()) {
305- case UNSIGNED_BYTE -> PanamaVectorUtilSupport .uint8DotProduct (query , doc );
306- case SEVEN_BIT -> PanamaVectorUtilSupport .uint8DotProduct (query , doc );
307- case PACKED_NIBBLE -> PanamaVectorUtilSupport .int4DotProductBothPacked (query , doc );
272+ case UNSIGNED_BYTE -> PanamaVectorUtilSupport .uint8DotProduct (query , doc .vector ());
273+ case SEVEN_BIT -> PanamaVectorUtilSupport .uint8DotProduct (query , doc .vector ());
274+ case PACKED_NIBBLE ->
275+ PanamaVectorUtilSupport .int4DotProductBothPacked (query , doc .vector ());
308276 };
309277 // Call getCorrectiveTerms() after computing dot product since corrective terms
310- // bytes appear
311- // after the vector bytes, so this sequence of calls is more cache friendly.
312- return getSimilarity ().score (dotProduct , queryCorrectiveTerms , getCorrectiveTerms (node ));
278+ // bytes appear after the vector bytes, so this sequence of calls is more cache
279+ // friendly.
280+ return getSimilarity ()
281+ .score (
282+ dotProduct ,
283+ queryCorrectiveTerms ,
284+ doc .lowerInterval (),
285+ doc .upperInterval (),
286+ doc .additionalCorrection (),
287+ doc .componentSum ());
313288 }
314289 }
315290}
0 commit comments