diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES92Int7VectorsScorer.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES92Int7VectorsScorer.java index c405e0ad33677..8e08eba241f05 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES92Int7VectorsScorer.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES92Int7VectorsScorer.java @@ -41,6 +41,13 @@ public ES92Int7VectorsScorer(IndexInput in, int dimensions) { this.dimensions = dimensions; } + /** + * Checks if the current implementation supports fast native access. + */ + public boolean hasNativeAccess() { + return false; // This class does not support native access + } + /** * compute the quantize distance between the provided quantized query and the quantized vector * that is read from the wrapped {@link IndexInput}. diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java index 6edf60fff1c83..29f9d6d43c130 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java @@ -8,255 +8,32 @@ */ package org.elasticsearch.simdvec.internal; -import jdk.incubator.vector.ByteVector; -import jdk.incubator.vector.FloatVector; -import jdk.incubator.vector.IntVector; -import jdk.incubator.vector.ShortVector; -import jdk.incubator.vector.Vector; -import jdk.incubator.vector.VectorOperators; -import jdk.incubator.vector.VectorShape; -import jdk.incubator.vector.VectorSpecies; - import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; -import org.apache.lucene.util.VectorUtil; -import org.elasticsearch.simdvec.ES92Int7VectorsScorer; import java.io.IOException; import java.lang.foreign.MemorySegment; -import java.nio.ByteOrder; - -import static java.nio.ByteOrder.LITTLE_ENDIAN; -import static jdk.incubator.vector.VectorOperators.ADD; -import static jdk.incubator.vector.VectorOperators.B2I; -import static jdk.incubator.vector.VectorOperators.B2S; -import static jdk.incubator.vector.VectorOperators.S2I; -import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; -import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; /** Panamized scorer for 7-bit quantized vectors stored as an {@link IndexInput}. **/ -public final class MemorySegmentES92Int7VectorsScorer extends ES92Int7VectorsScorer { - - private static final VectorSpecies BYTE_SPECIES_64 = ByteVector.SPECIES_64; - private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; - - private static final VectorSpecies SHORT_SPECIES_128 = ShortVector.SPECIES_128; - private static final VectorSpecies SHORT_SPECIES_256 = ShortVector.SPECIES_256; - - private static final VectorSpecies INT_SPECIES_128 = IntVector.SPECIES_128; - private static final VectorSpecies INT_SPECIES_256 = IntVector.SPECIES_256; - private static final VectorSpecies INT_SPECIES_512 = IntVector.SPECIES_512; - - private static final int VECTOR_BITSIZE; - private static final VectorSpecies FLOAT_SPECIES; - private static final VectorSpecies INT_SPECIES; - - static { - // default to platform supported bitsize - VECTOR_BITSIZE = VectorShape.preferredShape().vectorBitSize(); - FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(VECTOR_BITSIZE)); - INT_SPECIES = VectorSpecies.of(int.class, VectorShape.forBitSize(VECTOR_BITSIZE)); - } - - private final MemorySegment memorySegment; +public final class MemorySegmentES92Int7VectorsScorer extends MemorySegmentES92PanamaInt7VectorsScorer { public MemorySegmentES92Int7VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) { - super(in, dimensions); - this.memorySegment = memorySegment; + super(in, dimensions, memorySegment); } @Override - public long int7DotProduct(byte[] q) throws IOException { - assert dimensions == q.length; - int i = 0; - int res = 0; - // only vectorize if we'll at least enter the loop a single time - if (dimensions >= 16) { - // compute vectorized dot product consistent with VPDPBUSD instruction - if (VECTOR_BITSIZE >= 512) { - i += BYTE_SPECIES_128.loopBound(dimensions); - res += dotProductBody512(q, i); - } else if (VECTOR_BITSIZE == 256) { - i += BYTE_SPECIES_64.loopBound(dimensions); - res += dotProductBody256(q, i); - } else { - // tricky: we don't have SPECIES_32, so we workaround with "overlapping read" - i += BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length()); - res += dotProductBody128(q, i); - } - // scalar tail - while (i < dimensions) { - res += in.readByte() * q[i++]; - } - return res; - } else { - return super.int7DotProduct(q); - } - } - - private int dotProductBody512(byte[] q, int limit) throws IOException { - IntVector acc = IntVector.zero(INT_SPECIES_512); - long offset = in.getFilePointer(); - for (int i = 0; i < limit; i += BYTE_SPECIES_128.length()) { - ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i); - ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN); - - // 16-bit multiply: avoid AVX-512 heavy multiply on zmm - Vector va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0); - Vector vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0); - Vector prod16 = va16.mul(vb16); - - // 32-bit add - Vector prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0); - acc = acc.add(prod32); - } - - in.seek(offset + limit); // advance the input stream - // reduce - return acc.reduceLanes(ADD); + public boolean hasNativeAccess() { + return false; // This class does not support native access } - private int dotProductBody256(byte[] q, int limit) throws IOException { - IntVector acc = IntVector.zero(INT_SPECIES_256); - long offset = in.getFilePointer(); - for (int i = 0; i < limit; i += BYTE_SPECIES_64.length()) { - ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i); - ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN); - - // 32-bit multiply and add into accumulator - Vector va32 = va8.convertShape(B2I, INT_SPECIES_256, 0); - Vector vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0); - acc = acc.add(va32.mul(vb32)); - } - in.seek(offset + limit); - // reduce - return acc.reduceLanes(ADD); - } - - private int dotProductBody128(byte[] q, int limit) throws IOException { - IntVector acc = IntVector.zero(IntVector.SPECIES_128); - long offset = in.getFilePointer(); - // 4 bytes at a time (re-loading half the vector each time!) - for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) { - // load 8 bytes - ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i); - ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN); - - // process first "half" only: 16-bit multiply - Vector va16 = va8.convert(B2S, 0); - Vector vb16 = vb8.convert(B2S, 0); - Vector prod16 = va16.mul(vb16); - - // 32-bit add - acc = acc.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0)); - } - in.seek(offset + limit); - // reduce - return acc.reduceLanes(ADD); + @Override + public long int7DotProduct(byte[] q) throws IOException { + return panamaInt7DotProduct(q); } @Override public void int7DotProductBulk(byte[] q, int count, float[] scores) throws IOException { - assert dimensions == q.length; - // only vectorize if we'll at least enter the loop a single time - if (dimensions >= 16) { - // compute vectorized dot product consistent with VPDPBUSD instruction - if (VECTOR_BITSIZE >= 512) { - dotProductBody512Bulk(q, count, scores); - } else if (VECTOR_BITSIZE == 256) { - dotProductBody256Bulk(q, count, scores); - } else { - // tricky: we don't have SPECIES_32, so we workaround with "overlapping read" - dotProductBody128Bulk(q, count, scores); - } - } else { - int7DotProductBulk(q, count, scores); - } - } - - private void dotProductBody512Bulk(byte[] q, int count, float[] scores) throws IOException { - int limit = BYTE_SPECIES_128.loopBound(dimensions); - for (int iter = 0; iter < count; iter++) { - IntVector acc = IntVector.zero(INT_SPECIES_512); - long offset = in.getFilePointer(); - int i = 0; - for (; i < limit; i += BYTE_SPECIES_128.length()) { - ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i); - ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN); - - // 16-bit multiply: avoid AVX-512 heavy multiply on zmm - Vector va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0); - Vector vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0); - Vector prod16 = va16.mul(vb16); - - // 32-bit add - Vector prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0); - acc = acc.add(prod32); - } - - in.seek(offset + limit); // advance the input stream - // reduce - long res = acc.reduceLanes(ADD); - for (; i < dimensions; i++) { - res += in.readByte() * q[i]; - } - scores[iter] = res; - } - } - - private void dotProductBody256Bulk(byte[] q, int count, float[] scores) throws IOException { - int limit = BYTE_SPECIES_128.loopBound(dimensions); - for (int iter = 0; iter < count; iter++) { - IntVector acc = IntVector.zero(INT_SPECIES_256); - long offset = in.getFilePointer(); - int i = 0; - for (; i < limit; i += BYTE_SPECIES_64.length()) { - ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i); - ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN); - - // 32-bit multiply and add into accumulator - Vector va32 = va8.convertShape(B2I, INT_SPECIES_256, 0); - Vector vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0); - acc = acc.add(va32.mul(vb32)); - } - in.seek(offset + limit); - // reduce - long res = acc.reduceLanes(ADD); - for (; i < dimensions; i++) { - res += in.readByte() * q[i]; - } - scores[iter] = res; - } - } - - private void dotProductBody128Bulk(byte[] q, int count, float[] scores) throws IOException { - int limit = BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length()); - for (int iter = 0; iter < count; iter++) { - IntVector acc = IntVector.zero(IntVector.SPECIES_128); - long offset = in.getFilePointer(); - // 4 bytes at a time (re-loading half the vector each time!) - int i = 0; - for (; i < limit; i += ByteVector.SPECIES_64.length() >> 1) { - // load 8 bytes - ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i); - ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN); - - // process first "half" only: 16-bit multiply - Vector va16 = va8.convert(B2S, 0); - Vector vb16 = vb8.convert(B2S, 0); - Vector prod16 = va16.mul(vb16); - - // 32-bit add - acc = acc.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0)); - } - in.seek(offset + limit); - // reduce - long res = acc.reduceLanes(ADD); - for (; i < dimensions; i++) { - res += in.readByte() * q[i]; - } - scores[iter] = res; - } + panamaInt7DotProductBulk(q, count, scores); } @Override @@ -281,72 +58,4 @@ public void scoreBulk( scores ); } - - private void applyCorrectionsBulk( - float queryLowerInterval, - float queryUpperInterval, - int queryComponentSum, - float queryAdditionalCorrection, - VectorSimilarityFunction similarityFunction, - float centroidDp, - float[] scores - ) throws IOException { - int limit = FLOAT_SPECIES.loopBound(BULK_SIZE); - int i = 0; - long offset = in.getFilePointer(); - float ay = queryLowerInterval; - float ly = (queryUpperInterval - ay) * SEVEN_BIT_SCALE; - float y1 = queryComponentSum; - for (; i < limit; i += FLOAT_SPECIES.length()) { - var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); - var lx = FloatVector.fromMemorySegment( - FLOAT_SPECIES, - memorySegment, - offset + 4 * BULK_SIZE + i * Float.BYTES, - ByteOrder.LITTLE_ENDIAN - ).sub(ax).mul(SEVEN_BIT_SCALE); - var targetComponentSums = IntVector.fromMemorySegment( - INT_SPECIES, - memorySegment, - offset + 8 * BULK_SIZE + i * Integer.BYTES, - ByteOrder.LITTLE_ENDIAN - ).convert(VectorOperators.I2F, 0); - var additionalCorrections = FloatVector.fromMemorySegment( - FLOAT_SPECIES, - memorySegment, - offset + 12 * BULK_SIZE + i * Float.BYTES, - ByteOrder.LITTLE_ENDIAN - ); - var qcDist = FloatVector.fromArray(FLOAT_SPECIES, scores, i); - // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * - // qcDist; - var res1 = ax.mul(ay).mul(dimensions); - var res2 = lx.mul(ay).mul(targetComponentSums); - var res3 = ax.mul(ly).mul(y1); - var res4 = lx.mul(ly).mul(qcDist); - var res = res1.add(res2).add(res3).add(res4); - // For euclidean, we need to invert the score and apply the additional correction, which is - // assumed to be the squared l2norm of the centroid centered vectors. - if (similarityFunction == EUCLIDEAN) { - res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f); - res = FloatVector.broadcast(FLOAT_SPECIES, 1).div(res).max(0); - res.intoArray(scores, i); - } else { - // For cosine and max inner product, we need to apply the additional correction, which is - // assumed to be the non-centered dot-product between the vector and the centroid - res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp); - if (similarityFunction == MAXIMUM_INNER_PRODUCT) { - res.intoArray(scores, i); - // not sure how to do it better - for (int j = 0; j < FLOAT_SPECIES.length(); j++) { - scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]); - } - } else { - res = res.add(1f).mul(0.5f).max(0); - res.intoArray(scores, i); - } - } - } - in.seek(offset + 16L * BULK_SIZE); - } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92PanamaInt7VectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92PanamaInt7VectorsScorer.java new file mode 100644 index 0000000000000..bdd8b17900c0a --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92PanamaInt7VectorsScorer.java @@ -0,0 +1,327 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.simdvec.internal; + +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.FloatVector; +import jdk.incubator.vector.IntVector; +import jdk.incubator.vector.ShortVector; +import jdk.incubator.vector.Vector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorShape; +import jdk.incubator.vector.VectorSpecies; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.simdvec.ES92Int7VectorsScorer; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.nio.ByteOrder; + +import static java.nio.ByteOrder.LITTLE_ENDIAN; +import static jdk.incubator.vector.VectorOperators.ADD; +import static jdk.incubator.vector.VectorOperators.B2I; +import static jdk.incubator.vector.VectorOperators.B2S; +import static jdk.incubator.vector.VectorOperators.S2I; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +/** Panamized scorer for 7-bit quantized vectors stored as an {@link IndexInput}. **/ +abstract class MemorySegmentES92PanamaInt7VectorsScorer extends ES92Int7VectorsScorer { + + private static final VectorSpecies BYTE_SPECIES_64 = ByteVector.SPECIES_64; + private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; + + private static final VectorSpecies SHORT_SPECIES_128 = ShortVector.SPECIES_128; + private static final VectorSpecies SHORT_SPECIES_256 = ShortVector.SPECIES_256; + + private static final VectorSpecies INT_SPECIES_128 = IntVector.SPECIES_128; + private static final VectorSpecies INT_SPECIES_256 = IntVector.SPECIES_256; + private static final VectorSpecies INT_SPECIES_512 = IntVector.SPECIES_512; + + private static final int VECTOR_BITSIZE; + private static final VectorSpecies FLOAT_SPECIES; + private static final VectorSpecies INT_SPECIES; + + static { + // default to platform supported bitsize + VECTOR_BITSIZE = VectorShape.preferredShape().vectorBitSize(); + FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(VECTOR_BITSIZE)); + INT_SPECIES = VectorSpecies.of(int.class, VectorShape.forBitSize(VECTOR_BITSIZE)); + } + + protected final MemorySegment memorySegment; + + protected MemorySegmentES92PanamaInt7VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) { + super(in, dimensions); + this.memorySegment = memorySegment; + } + + protected long panamaInt7DotProduct(byte[] q) throws IOException { + assert dimensions == q.length; + int i = 0; + int res = 0; + // only vectorize if we'll at least enter the loop a single time + if (dimensions >= 16) { + // compute vectorized dot product consistent with VPDPBUSD instruction + if (VECTOR_BITSIZE >= 512) { + i += BYTE_SPECIES_128.loopBound(dimensions); + res += dotProductBody512(q, i); + } else if (VECTOR_BITSIZE == 256) { + i += BYTE_SPECIES_64.loopBound(dimensions); + res += dotProductBody256(q, i); + } else { + // tricky: we don't have SPECIES_32, so we workaround with "overlapping read" + i += BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length()); + res += dotProductBody128(q, i); + } + // scalar tail + while (i < dimensions) { + res += in.readByte() * q[i++]; + } + return res; + } else { + return super.int7DotProduct(q); + } + } + + private int dotProductBody512(byte[] q, int limit) throws IOException { + IntVector acc = IntVector.zero(INT_SPECIES_512); + long offset = in.getFilePointer(); + for (int i = 0; i < limit; i += BYTE_SPECIES_128.length()) { + ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i); + ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN); + + // 16-bit multiply: avoid AVX-512 heavy multiply on zmm + Vector va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0); + Vector vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0); + Vector prod16 = va16.mul(vb16); + + // 32-bit add + Vector prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0); + acc = acc.add(prod32); + } + + in.seek(offset + limit); // advance the input stream + // reduce + return acc.reduceLanes(ADD); + } + + private int dotProductBody256(byte[] q, int limit) throws IOException { + IntVector acc = IntVector.zero(INT_SPECIES_256); + long offset = in.getFilePointer(); + for (int i = 0; i < limit; i += BYTE_SPECIES_64.length()) { + ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i); + ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN); + + // 32-bit multiply and add into accumulator + Vector va32 = va8.convertShape(B2I, INT_SPECIES_256, 0); + Vector vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0); + acc = acc.add(va32.mul(vb32)); + } + in.seek(offset + limit); + // reduce + return acc.reduceLanes(ADD); + } + + private int dotProductBody128(byte[] q, int limit) throws IOException { + IntVector acc = IntVector.zero(IntVector.SPECIES_128); + long offset = in.getFilePointer(); + // 4 bytes at a time (re-loading half the vector each time!) + for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) { + // load 8 bytes + ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i); + ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN); + + // process first "half" only: 16-bit multiply + Vector va16 = va8.convert(B2S, 0); + Vector vb16 = vb8.convert(B2S, 0); + Vector prod16 = va16.mul(vb16); + + // 32-bit add + acc = acc.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0)); + } + in.seek(offset + limit); + // reduce + return acc.reduceLanes(ADD); + } + + protected void panamaInt7DotProductBulk(byte[] q, int count, float[] scores) throws IOException { + assert dimensions == q.length; + // only vectorize if we'll at least enter the loop a single time + if (dimensions >= 16) { + // compute vectorized dot product consistent with VPDPBUSD instruction + if (VECTOR_BITSIZE >= 512) { + dotProductBody512Bulk(q, count, scores); + } else if (VECTOR_BITSIZE == 256) { + dotProductBody256Bulk(q, count, scores); + } else { + // tricky: we don't have SPECIES_32, so we workaround with "overlapping read" + dotProductBody128Bulk(q, count, scores); + } + } else { + int7DotProductBulk(q, count, scores); + } + } + + private void dotProductBody512Bulk(byte[] q, int count, float[] scores) throws IOException { + int limit = BYTE_SPECIES_128.loopBound(dimensions); + for (int iter = 0; iter < count; iter++) { + IntVector acc = IntVector.zero(INT_SPECIES_512); + long offset = in.getFilePointer(); + int i = 0; + for (; i < limit; i += BYTE_SPECIES_128.length()) { + ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i); + ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN); + + // 16-bit multiply: avoid AVX-512 heavy multiply on zmm + Vector va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0); + Vector vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0); + Vector prod16 = va16.mul(vb16); + + // 32-bit add + Vector prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0); + acc = acc.add(prod32); + } + + in.seek(offset + limit); // advance the input stream + // reduce + long res = acc.reduceLanes(ADD); + for (; i < dimensions; i++) { + res += in.readByte() * q[i]; + } + scores[iter] = res; + } + } + + private void dotProductBody256Bulk(byte[] q, int count, float[] scores) throws IOException { + int limit = BYTE_SPECIES_128.loopBound(dimensions); + for (int iter = 0; iter < count; iter++) { + IntVector acc = IntVector.zero(INT_SPECIES_256); + long offset = in.getFilePointer(); + int i = 0; + for (; i < limit; i += BYTE_SPECIES_64.length()) { + ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i); + ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN); + + // 32-bit multiply and add into accumulator + Vector va32 = va8.convertShape(B2I, INT_SPECIES_256, 0); + Vector vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0); + acc = acc.add(va32.mul(vb32)); + } + in.seek(offset + limit); + // reduce + long res = acc.reduceLanes(ADD); + for (; i < dimensions; i++) { + res += in.readByte() * q[i]; + } + scores[iter] = res; + } + } + + private void dotProductBody128Bulk(byte[] q, int count, float[] scores) throws IOException { + int limit = BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length()); + for (int iter = 0; iter < count; iter++) { + IntVector acc = IntVector.zero(IntVector.SPECIES_128); + long offset = in.getFilePointer(); + // 4 bytes at a time (re-loading half the vector each time!) + int i = 0; + for (; i < limit; i += ByteVector.SPECIES_64.length() >> 1) { + // load 8 bytes + ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i); + ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN); + + // process first "half" only: 16-bit multiply + Vector va16 = va8.convert(B2S, 0); + Vector vb16 = vb8.convert(B2S, 0); + Vector prod16 = va16.mul(vb16); + + // 32-bit add + acc = acc.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0)); + } + in.seek(offset + limit); + // reduce + long res = acc.reduceLanes(ADD); + for (; i < dimensions; i++) { + res += in.readByte() * q[i]; + } + scores[iter] = res; + } + } + + protected void applyCorrectionsBulk( + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores + ) throws IOException { + int limit = FLOAT_SPECIES.loopBound(BULK_SIZE); + int i = 0; + long offset = in.getFilePointer(); + float ay = queryLowerInterval; + float ly = (queryUpperInterval - ay) * SEVEN_BIT_SCALE; + float y1 = queryComponentSum; + for (; i < limit; i += FLOAT_SPECIES.length()) { + var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); + var lx = FloatVector.fromMemorySegment( + FLOAT_SPECIES, + memorySegment, + offset + 4 * BULK_SIZE + i * Float.BYTES, + ByteOrder.LITTLE_ENDIAN + ).sub(ax).mul(SEVEN_BIT_SCALE); + var targetComponentSums = IntVector.fromMemorySegment( + INT_SPECIES, + memorySegment, + offset + 8 * BULK_SIZE + i * Integer.BYTES, + ByteOrder.LITTLE_ENDIAN + ).convert(VectorOperators.I2F, 0); + var additionalCorrections = FloatVector.fromMemorySegment( + FLOAT_SPECIES, + memorySegment, + offset + 12 * BULK_SIZE + i * Float.BYTES, + ByteOrder.LITTLE_ENDIAN + ); + var qcDist = FloatVector.fromArray(FLOAT_SPECIES, scores, i); + // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * + // qcDist; + var res1 = ax.mul(ay).mul(dimensions); + var res2 = lx.mul(ay).mul(targetComponentSums); + var res3 = ax.mul(ly).mul(y1); + var res4 = lx.mul(ly).mul(qcDist); + var res = res1.add(res2).add(res3).add(res4); + // For euclidean, we need to invert the score and apply the additional correction, which is + // assumed to be the squared l2norm of the centroid centered vectors. + if (similarityFunction == EUCLIDEAN) { + res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f); + res = FloatVector.broadcast(FLOAT_SPECIES, 1).div(res).max(0); + res.intoArray(scores, i); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp); + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + res.intoArray(scores, i); + // not sure how to do it better + for (int j = 0; j < FLOAT_SPECIES.length(); j++) { + scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]); + } + } else { + res = res.add(1f).mul(0.5f).max(0); + res.intoArray(scores, i); + } + } + } + in.seek(offset + 16L * BULK_SIZE); + } +} diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java index 1b60471b33b59..8abf05098deaf 100644 --- a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java +++ b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java @@ -8,46 +8,39 @@ */ package org.elasticsearch.simdvec.internal; -import jdk.incubator.vector.FloatVector; -import jdk.incubator.vector.IntVector; -import jdk.incubator.vector.VectorOperators; -import jdk.incubator.vector.VectorShape; -import jdk.incubator.vector.VectorSpecies; - import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; -import org.apache.lucene.util.VectorUtil; -import org.elasticsearch.simdvec.ES92Int7VectorsScorer; +import org.elasticsearch.nativeaccess.NativeAccess; import java.io.IOException; import java.lang.foreign.MemorySegment; -import java.nio.ByteOrder; - -import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; -import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; /** Native / panamized scorer for 7-bit quantized vectors stored as an {@link IndexInput}. **/ -public final class MemorySegmentES92Int7VectorsScorer extends ES92Int7VectorsScorer { +public final class MemorySegmentES92Int7VectorsScorer extends MemorySegmentES92PanamaInt7VectorsScorer { - private static final VectorSpecies FLOAT_SPECIES; - private static final VectorSpecies INT_SPECIES; + private static final boolean NATIVE_SUPPORTED = NativeAccess.instance().getVectorSimilarityFunctions().isPresent(); - static { - // default to platform supported bitsize - final int vectorBitSize = VectorShape.preferredShape().vectorBitSize(); - FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(vectorBitSize)); - INT_SPECIES = VectorSpecies.of(int.class, VectorShape.forBitSize(vectorBitSize)); + public MemorySegmentES92Int7VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) { + super(in, dimensions, memorySegment); } - private final MemorySegment memorySegment; - - public MemorySegmentES92Int7VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) { - super(in, dimensions); - this.memorySegment = memorySegment; + @Override + public boolean hasNativeAccess() { + return NATIVE_SUPPORTED; } @Override public long int7DotProduct(byte[] q) throws IOException { + assert q.length == dimensions; + if (NATIVE_SUPPORTED) { + return nativeInt7DotProduct(q); + } else { + return panamaInt7DotProduct(q); + } + + } + + private long nativeInt7DotProduct(byte[] q) throws IOException { final MemorySegment segment = memorySegment.asSlice(in.getFilePointer(), dimensions); final MemorySegment querySegment = MemorySegment.ofArray(q); final long res = Similarities.dotProduct7u(segment, querySegment, dimensions); @@ -57,9 +50,14 @@ public long int7DotProduct(byte[] q) throws IOException { @Override public void int7DotProductBulk(byte[] q, int count, float[] scores) throws IOException { - // TODO: can we speed up bulks in native code? - for (int i = 0; i < count; i++) { - scores[i] = int7DotProduct(q); + assert q.length == dimensions; + if (NATIVE_SUPPORTED) { + // TODO: can we speed up bulks in native code? + for (int i = 0; i < count; i++) { + scores[i] = nativeInt7DotProduct(q); + } + } else { + panamaInt7DotProductBulk(q, count, scores); } } @@ -85,72 +83,4 @@ public void scoreBulk( scores ); } - - private void applyCorrectionsBulk( - float queryLowerInterval, - float queryUpperInterval, - int queryComponentSum, - float queryAdditionalCorrection, - VectorSimilarityFunction similarityFunction, - float centroidDp, - float[] scores - ) throws IOException { - int limit = FLOAT_SPECIES.loopBound(BULK_SIZE); - int i = 0; - long offset = in.getFilePointer(); - float ay = queryLowerInterval; - float ly = (queryUpperInterval - ay) * SEVEN_BIT_SCALE; - float y1 = queryComponentSum; - for (; i < limit; i += FLOAT_SPECIES.length()) { - var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); - var lx = FloatVector.fromMemorySegment( - FLOAT_SPECIES, - memorySegment, - offset + 4 * BULK_SIZE + i * Float.BYTES, - ByteOrder.LITTLE_ENDIAN - ).sub(ax).mul(SEVEN_BIT_SCALE); - var targetComponentSums = IntVector.fromMemorySegment( - INT_SPECIES, - memorySegment, - offset + 8 * BULK_SIZE + i * Integer.BYTES, - ByteOrder.LITTLE_ENDIAN - ).convert(VectorOperators.I2F, 0); - var additionalCorrections = FloatVector.fromMemorySegment( - FLOAT_SPECIES, - memorySegment, - offset + 12 * BULK_SIZE + i * Float.BYTES, - ByteOrder.LITTLE_ENDIAN - ); - var qcDist = FloatVector.fromArray(FLOAT_SPECIES, scores, i); - // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * - // qcDist; - var res1 = ax.mul(ay).mul(dimensions); - var res2 = lx.mul(ay).mul(targetComponentSums); - var res3 = ax.mul(ly).mul(y1); - var res4 = lx.mul(ly).mul(qcDist); - var res = res1.add(res2).add(res3).add(res4); - // For euclidean, we need to invert the score and apply the additional correction, which is - // assumed to be the squared l2norm of the centroid centered vectors. - if (similarityFunction == EUCLIDEAN) { - res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f); - res = FloatVector.broadcast(FLOAT_SPECIES, 1).div(res).max(0); - res.intoArray(scores, i); - } else { - // For cosine and max inner product, we need to apply the additional correction, which is - // assumed to be the non-centered dot-product between the vector and the centroid - res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp); - if (similarityFunction == MAXIMUM_INNER_PRODUCT) { - res.intoArray(scores, i); - // not sure how to do it better - for (int j = 0; j < FLOAT_SPECIES.length(); j++) { - scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]); - } - } else { - res = res.add(1f).mul(0.5f).max(0); - res.intoArray(scores, i); - } - } - } - in.seek(offset + 16L * BULK_SIZE); - } } diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES92Int7VectorScorerTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES92Int7VectorScorerTests.java index 31ef6092539e7..5b82fb9ea20e8 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES92Int7VectorScorerTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES92Int7VectorScorerTests.java @@ -27,6 +27,15 @@ public class ES92Int7VectorScorerTests extends BaseVectorizationTests { + public boolean hasNativeAccess() { + var jdkVersion = Runtime.version().feature(); + var arch = System.getProperty("os.arch"); + var osName = System.getProperty("os.name"); + return (jdkVersion >= 22 + && (arch.equals("aarch64") && (osName.startsWith("Mac") || osName.equals("Linux")) + || arch.equals("amd64") && osName.equals("Linux"))); + } + public void testInt7DotProduct() throws Exception { // only even dimensions are supported final int dimensions = random().nextInt(1, 1000) * 2; @@ -52,7 +61,9 @@ public void testInt7DotProduct() throws Exception { final IndexInput slice = in.slice("test", 0, (long) dimensions * numVectors); final IndexInput slice2 = in.slice("test2", 0, (long) dimensions * numVectors); final ES92Int7VectorsScorer defaultScorer = defaultProvider().newES92Int7VectorsScorer(slice, dimensions); + assertFalse(defaultScorer.hasNativeAccess()); final ES92Int7VectorsScorer panamaScorer = maybePanamaProvider().newES92Int7VectorsScorer(slice2, dimensions); + assertEquals(panamaScorer.hasNativeAccess(), hasNativeAccess()); for (int i = 0; i < numVectors; i++) { in.readBytes(vector, 0, dimensions); long val = VectorUtil.dotProduct(vector, query); @@ -119,7 +130,9 @@ public void testInt7Score() throws Exception { // padding bytes. final IndexInput slice = in.slice("test", 0, (long) (dimensions + 16) * numVectors); final ES92Int7VectorsScorer defaultScorer = defaultProvider().newES92Int7VectorsScorer(in, dimensions); + assertFalse(defaultScorer.hasNativeAccess()); final ES92Int7VectorsScorer panamaScorer = maybePanamaProvider().newES92Int7VectorsScorer(slice, dimensions); + assertEquals(panamaScorer.hasNativeAccess(), hasNativeAccess()); for (int i = 0; i < numVectors; i++) { float scoreDefault = defaultScorer.score( qQuery, @@ -198,7 +211,9 @@ public void testInt7ScoreBulk() throws Exception { // padding bytes. final IndexInput slice = in.slice("test", 0, (long) (dimensions + 16) * numVectors); final ES92Int7VectorsScorer defaultScorer = defaultProvider().newES92Int7VectorsScorer(in, dimensions); + assertFalse(defaultScorer.hasNativeAccess()); final ES92Int7VectorsScorer panamaScorer = maybePanamaProvider().newES92Int7VectorsScorer(slice, dimensions); + assertEquals(panamaScorer.hasNativeAccess(), hasNativeAccess()); float[] scoresDefault = new float[ES91Int4VectorsScorer.BULK_SIZE]; float[] scoresPanama = new float[ES91Int4VectorsScorer.BULK_SIZE]; for (int i = 0; i < numVectors; i += ES91Int4VectorsScorer.BULK_SIZE) {