@@ -11,6 +11,7 @@

 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.Element;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
 import org.elasticsearch.script.field.vectors.DenseVector;
 import org.elasticsearch.script.field.vectors.DenseVectorDocValuesField;
@@ -66,7 +67,7 @@ public ByteDenseVectorFunction(
 ElementType... allowedTypes
 ) {
 super(scoreScript, field);
-field.getElementType().checkDimensions(field.get().getDims(), queryVector.size());
+field.getElement().checkDimensions(field.get().getDims(), queryVector.size());
 float[] floatValues = new float[queryVector.size()];
 double queryMagnitude = 0;
 for (int i = 0; i < queryVector.size(); i++) {
@@ -76,7 +77,7 @@
 }
 queryMagnitude = Math.sqrt(queryMagnitude);

-switch (ElementType.checkValidVector(floatValues, allowedTypes)) {
+switch (Element.checkValidVector(floatValues, allowedTypes)) {
 case FLOAT:
 byteQueryVector = null;
 floatQueryVector = floatValues;
@@ -149,7 +150,7 @@ public FloatDenseVectorFunction(
 queryMagnitude += value * value;
 }
 queryMagnitude = Math.sqrt(queryMagnitude);
-field.getElementType().checkVectorBounds(this.queryVector);
+field.getElement().checkVectorBounds(this.queryVector);

 if (normalizeQuery) {
 for (int dim = 0; dim < this.queryVector.length; dim++) {
[next file]
@@ -16,6 +16,7 @@

 import java.util.Iterator;

+import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.Element;
 import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;

 public abstract class DenseVectorDocValuesField extends AbstractScriptFieldFactory<DenseVector>
@@ -40,6 +41,10 @@ public ElementType getElementType() {
 return elementType;
 }

+public Element getElement() {
+    return Element.getElement(elementType);
+}
+
 @Override
 public int size() {
 return isEmpty() ? 0 : 1;
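Note on the pattern in this change: call sites stop invoking validation and encoding helpers directly on the ElementType enum and instead go through an Element object, obtained either from Element.getElement(elementType) or from the new getElement() accessor added above. The Element API itself lives in DenseVectorFieldMapper and is not part of the hunks in this diff, so the following is only a minimal, self-contained sketch of the enum-to-strategy delegation being applied; every name below (ElementKind, VectorElement, FloatElement, ByteElement) is an invented stand-in, not the real Elasticsearch type.

import java.util.Locale;

// Illustrative sketch only: the real types are DenseVectorFieldMapper.ElementType and
// DenseVectorFieldMapper.Element; this models the same enum -> strategy-object lookup.
enum ElementKind { FLOAT, BYTE }

interface VectorElement {
    ElementKind kind();

    int numBytes(int dims);

    default void checkDimensions(int fieldDims, int queryDims) {
        if (fieldDims != queryDims) {
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "query vector has %d dimensions, field has %d", queryDims, fieldDims)
            );
        }
    }

    // Static lookup mirroring Element.getElement(ElementType) as it is used in this diff.
    static VectorElement of(ElementKind kind) {
        return switch (kind) {
            case FLOAT -> new FloatElement();
            case BYTE -> new ByteElement();
        };
    }
}

final class FloatElement implements VectorElement {
    public ElementKind kind() { return ElementKind.FLOAT; }

    public int numBytes(int dims) { return dims * Float.BYTES; } // four bytes per dimension
}

final class ByteElement implements VectorElement {
    public ElementKind kind() { return ElementKind.BYTE; }

    public int numBytes(int dims) { return dims; } // one byte per dimension
}

class ElementDelegationDemo {
    public static void main(String[] args) {
        // Call sites ask the element object, mirroring field.getElement().checkDimensions(...),
        // instead of reaching through the ElementType enum for behaviour.
        VectorElement element = VectorElement.of(ElementKind.FLOAT);
        element.checkDimensions(4, 4);
        System.out.println("bytes for one 4-dim float vector: " + element.numBytes(4)); // 16
    }
}

The later hunks show the practical payoff: fields can hold an Element instead of re-resolving behaviour from the enum on every call, and constants such as DenseVectorFieldMapper.FLOAT_ELEMENT and BYTE_ELEMENT replace ElementType.FLOAT and ElementType.BYTE at call sites that need a specific element.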
[next file]
@@ -9,15 +9,15 @@

 package org.elasticsearch.script.field.vectors;

+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.Element;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
 import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues;
 import org.elasticsearch.script.field.AbstractScriptFieldFactory;
 import org.elasticsearch.script.field.DocValuesScriptFieldFactory;
 import org.elasticsearch.script.field.Field;

 import java.util.Iterator;

-import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
-
 public abstract class RankVectorsDocValuesField extends AbstractScriptFieldFactory<RankVectors>
 implements
 Field<RankVectors>,
@@ -40,6 +40,10 @@ public ElementType getElementType() {
 return elementType;
 }

+public Element getElement() {
+    return Element.getElement(elementType);
+}
+
 /**
 * Get the DenseVector for a document if one exists, DenseVector.EMPTY otherwise
 */
[next file]
@@ -56,7 +56,7 @@ public byte[] asByteVector() {
 if (byteVector != null) {
 return byteVector;
 }
-DenseVectorFieldMapper.ElementType.BYTE.checkVectorBounds(floatVector);
+DenseVectorFieldMapper.BYTE_ELEMENT.checkVectorBounds(floatVector);
 byte[] vec = new byte[floatVector.length];
 for (int i = 0; i < floatVector.length; i++) {
 vec[i] = (byte) floatVector[i];
[next file]
@@ -12,6 +12,7 @@
 import org.apache.lucene.index.BinaryDocValues;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.index.IndexVersion;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.Element;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
 import org.elasticsearch.script.field.vectors.BinaryDenseVectorDocValuesField;
 import org.elasticsearch.script.field.vectors.ByteBinaryDenseVectorDocValuesField;
@@ -240,11 +241,12 @@ public static BytesRef mockEncodeDenseVector(float[] values, ElementType element
 if (elementType == ElementType.BIT) {
 dims *= Byte.SIZE;
 }
+Element element = Element.getElement(elementType);
 int numBytes = indexVersion.onOrAfter(DenseVectorFieldMapper.MAGNITUDE_STORED_INDEX_VERSION)
-? elementType.getNumBytes(dims) + DenseVectorFieldMapper.MAGNITUDE_BYTES
-: elementType.getNumBytes(dims);
+? element.getNumBytes(dims) + DenseVectorFieldMapper.MAGNITUDE_BYTES
+: element.getNumBytes(dims);
 double dotProduct = 0f;
-ByteBuffer byteBuffer = elementType.createByteBuffer(indexVersion, numBytes);
+ByteBuffer byteBuffer = element.createByteBuffer(indexVersion, numBytes);
 for (float value : values) {
 if (elementType == ElementType.FLOAT) {
 byteBuffer.putFloat(value);
[next file]
@@ -34,7 +34,7 @@ public void testVectorDecodingWithOffset() {
 ),
 DenseVectorFieldMapper.LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION
 )) {
-ByteBuffer byteBuffer = DenseVectorFieldMapper.ElementType.FLOAT.createByteBuffer(version, 20);
+ByteBuffer byteBuffer = DenseVectorFieldMapper.FLOAT_ELEMENT.createByteBuffer(version, 20);
 double magnitude = 0.0;
 for (float f : inputFloats) {
 byteBuffer.putFloat(f);
[next file]
@@ -29,6 +29,8 @@
 import org.elasticsearch.index.mapper.TextSearchInfo;
 import org.elasticsearch.index.mapper.ValueFetcher;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.Element;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
 import org.elasticsearch.index.mapper.vectors.SyntheticVectorsPatchFieldLoader;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.license.LicenseUtils;
@@ -77,7 +79,7 @@ public static class Builder extends FieldMapper.Builder {
 }
 return elementType;
 },
-m -> toType(m).fieldType().elementType,
+m -> toType(m).fieldType().element.elementType(),
 XContentBuilder::field,
 Objects::toString
 );
@@ -160,19 +162,19 @@ public RankVectorsFieldMapper build(MapperBuilderContext context) {
 }

 public static final class RankVectorsFieldType extends SimpleMappedFieldType {
-private final DenseVectorFieldMapper.ElementType elementType;
+private final Element element;
 private final Integer dims;
 private final XPackLicenseState licenseState;

 public RankVectorsFieldType(
 String name,
-DenseVectorFieldMapper.ElementType elementType,
+ElementType elementType,
 Integer dims,
 XPackLicenseState licenseState,
 Map<String, String> meta
 ) {
 super(name, false, false, true, TextSearchInfo.NONE, meta);
-this.elementType = elementType;
+this.element = Element.getElement(elementType);
 this.dims = dims;
 this.licenseState = licenseState;
 }
@@ -228,7 +230,7 @@ public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext
 if (RANK_VECTORS_FEATURE.check(licenseState) == false) {
 throw LicenseUtils.newComplianceException("Rank Vectors");
 }
-return new RankVectorsIndexFieldData.Builder(name(), CoreValuesSourceType.KEYWORD, dims, elementType);
+return new RankVectorsIndexFieldData.Builder(name(), CoreValuesSourceType.KEYWORD, dims, element.elementType());
 }

 @Override
@@ -246,7 +248,7 @@ int getVectorDimensions() {
 }

 DenseVectorFieldMapper.ElementType getElementType() {
-return elementType;
+return element.elementType();
 }
 }

@@ -303,7 +305,7 @@ public void parse(DocumentParserContext context) throws IOException {
 if (fieldType().dims == null) {
 int currentDims = -1;
 while (XContentParser.Token.END_ARRAY != context.parser().nextToken()) {
-int dims = fieldType().elementType.parseDimensionCount(context);
+int dims = fieldType().element.parseDimensionCount(context);
 if (currentDims == -1) {
 currentDims = dims;
 } else if (currentDims != dims) {
@@ -319,10 +321,10 @@ public void parse(DocumentParserContext context) throws IOException {
 return;
 }
 int dims = fieldType().dims;
-DenseVectorFieldMapper.ElementType elementType = fieldType().elementType;
+Element element = fieldType().element;
 List<VectorData> vectors = new ArrayList<>();
 while (XContentParser.Token.END_ARRAY != context.parser().nextToken()) {
-VectorData vector = elementType.parseKnnVector(context, dims, (i, b) -> {
+VectorData vector = element.parseKnnVector(context, dims, (i, b) -> {
 if (b) {
 checkDimensionMatches(i, context);
 } else {
@@ -331,12 +333,12 @@ public void parse(DocumentParserContext context) throws IOException {
 }, null);
 vectors.add(vector);
 }
-int bufferSize = elementType.getNumBytes(dims) * vectors.size();
+int bufferSize = element.getNumBytes(dims) * vectors.size();
 ByteBuffer buffer = ByteBuffer.allocate(bufferSize).order(ByteOrder.LITTLE_ENDIAN);
 ByteBuffer magnitudeBuffer = ByteBuffer.allocate(vectors.size() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
 for (VectorData vector : vectors) {
 vector.addToBuffer(buffer);
-magnitudeBuffer.putFloat((float) Math.sqrt(elementType.computeSquaredMagnitude(vector)));
+magnitudeBuffer.putFloat((float) Math.sqrt(element.computeSquaredMagnitude(vector)));
 }
 String vectorFieldName = fieldType().name();
 String vectorMagnitudeFieldName = vectorFieldName + VECTOR_MAGNITUDES_SUFFIX;
@@ -444,15 +446,15 @@ public void write(XContentBuilder b) throws IOException {
 b.startArray(leafName());
 BytesRef ref = values.binaryValue();
 ByteBuffer byteBuffer = ByteBuffer.wrap(ref.bytes, ref.offset, ref.length).order(ByteOrder.LITTLE_ENDIAN);
-assert ref.length % fieldType().elementType.getNumBytes(fieldType().dims) == 0;
-int numVecs = ref.length / fieldType().elementType.getNumBytes(fieldType().dims);
+assert ref.length % fieldType().element.getNumBytes(fieldType().dims) == 0;
+int numVecs = ref.length / fieldType().element.getNumBytes(fieldType().dims);
 for (int i = 0; i < numVecs; i++) {
 b.startArray();
-int dims = fieldType().elementType == DenseVectorFieldMapper.ElementType.BIT
+int dims = fieldType().element.elementType() == DenseVectorFieldMapper.ElementType.BIT
 ? fieldType().dims / Byte.SIZE
 : fieldType().dims;
 for (int dim = 0; dim < dims; dim++) {
-fieldType().elementType.readAndWriteValue(byteBuffer, b);
+fieldType().element.readAndWriteValue(byteBuffer, b);
 }
 b.endArray();
 }
@@ -468,15 +470,15 @@ private List<List<?>> copyVectorsAsList() throws IOException {
 assert hasValue : "rank vector is null";
 BytesRef ref = values.binaryValue();
 ByteBuffer byteBuffer = ByteBuffer.wrap(ref.bytes, ref.offset, ref.length).order(ByteOrder.LITTLE_ENDIAN);
-assert ref.length % fieldType().elementType.getNumBytes(fieldType().dims) == 0;
-int numVecs = ref.length / fieldType().elementType.getNumBytes(fieldType().dims);
+assert ref.length % fieldType().element.getNumBytes(fieldType().dims) == 0;
+int numVecs = ref.length / fieldType().element.getNumBytes(fieldType().dims);
 List<List<?>> vectors = new ArrayList<>(numVecs);
 for (int i = 0; i < numVecs; i++) {
-int dims = fieldType().elementType == DenseVectorFieldMapper.ElementType.BIT
+int dims = fieldType().element.elementType() == DenseVectorFieldMapper.ElementType.BIT
 ? fieldType().dims / Byte.SIZE
 : fieldType().dims;

-switch (fieldType().elementType) {
+switch (fieldType().element.elementType()) {
 case FLOAT -> {
 List<Float> vec = new ArrayList<>(dims);
 for (int dim = 0; dim < dims; dim++) {
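The parse() and write()/copyVectorsAsList() hunks above rely on a simple layout invariant: the binary doc value is a flat concatenation of fixed-size vectors, so the element's per-vector byte count alone determines how many vectors a document holds, and magnitudes are written separately as one float per vector. The assert before numVecs checks exactly that divisibility before the blob is sliced back into vectors. A small illustration of the arithmetic with made-up numbers, assuming a float element that stores four bytes per dimension:

class RankVectorsLayoutDemo {
    public static void main(String[] args) {
        // Hypothetical example: float element, 128 dimensions, 3 vectors in a single document.
        int dims = 128;
        int numVectors = 3;
        int bytesPerVector = dims * Float.BYTES;                  // element.getNumBytes(dims) for a float element -> 512
        int vectorBufferSize = bytesPerVector * numVectors;       // bufferSize allocated in parse() -> 1536
        int magnitudeBufferSize = numVectors * Float.BYTES;       // one float magnitude per vector -> 12
        int recoveredNumVecs = vectorBufferSize / bytesPerVector;  // numVecs recomputed in write() -> 3
        System.out.println(vectorBufferSize + " vector bytes, " + magnitudeBufferSize
            + " magnitude bytes, " + recoveredNumVecs + " vectors");
    }
}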
[next file]
@@ -55,7 +55,7 @@ public ByteRankVectorsFunction(ScoreScript scoreScript, RankVectorsDocValuesFiel
 if (queryVector.isEmpty()) {
 throw new IllegalArgumentException("The query vector is empty.");
 }
-field.getElementType().checkDimensions(field.get().getDims(), queryVector.get(0).size());
+field.getElement().checkDimensions(field.get().getDims(), queryVector.get(0).size());
 this.queryVector = new byte[queryVector.size()][queryVector.get(0).size()];
 float[] validateValues = new float[queryVector.size()];
 int lastSize = -1;
@@ -72,7 +72,7 @@ public ByteRankVectorsFunction(ScoreScript scoreScript, RankVectorsDocValuesFiel
 this.queryVector[i][j] = value;
 validateValues[i] = number.floatValue();
 }
-field.getElementType().checkVectorBounds(validateValues);
+field.getElement().checkVectorBounds(validateValues);
 }
 }

@@ -118,7 +118,7 @@ public FloatRankVectorsFunction(ScoreScript scoreScript, RankVectorsDocValuesFie
 for (int j = 0; j < queryVector.get(i).size(); j++) {
 this.queryVector[i][j] = queryVector.get(i).get(j).floatValue();
 }
-field.getElementType().checkVectorBounds(this.queryVector[i]);
+field.getElement().checkVectorBounds(this.queryVector[i]);
 }
 }
 }
[next file]
@@ -24,6 +24,7 @@
 import org.elasticsearch.index.mapper.ParsedDocument;
 import org.elasticsearch.index.mapper.SourceToParse;
 import org.elasticsearch.index.mapper.ValueFetcher;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
 import org.elasticsearch.index.mapper.vectors.SyntheticVectorsMapperTestCase;
 import org.elasticsearch.index.query.SearchExecutionContext;
@@ -232,7 +233,7 @@ public void testNonIndexedVector() throws Exception {
 assertThat(fields.get(0), instanceOf(BinaryDocValuesField.class));
 // assert that after decoding the indexed value is equal to expected
 BytesRef vectorBR = fields.get(0).binaryValue();
-assertEquals(ElementType.FLOAT.getNumBytes(validVectors[0].length) * validVectors.length, vectorBR.length);
+assertEquals(DenseVectorFieldMapper.FLOAT_ELEMENT.getNumBytes(validVectors[0].length) * validVectors.length, vectorBR.length);
 float[][] decodedValues = new float[validVectors.length][];
 for (int i = 0; i < validVectors.length; i++) {
 decodedValues[i] = new float[validVectors[i].length];
[next file]
@@ -10,6 +10,7 @@
 import org.apache.lucene.index.BinaryDocValues;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.index.IndexVersion;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.Element;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
 import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues;
 import org.elasticsearch.script.field.vectors.ByteRankVectorsDocValuesField;
@@ -300,8 +301,9 @@ public static BytesRef mockEncodeDenseVector(float[][] values, ElementType eleme
 if (elementType == ElementType.BIT) {
 dims *= Byte.SIZE;
 }
-int numBytes = elementType.getNumBytes(dims);
-ByteBuffer byteBuffer = elementType.createByteBuffer(indexVersion, numBytes * values.length);
+Element element = Element.getElement(elementType);
+int numBytes = element.getNumBytes(dims);
+ByteBuffer byteBuffer = element.createByteBuffer(indexVersion, numBytes * values.length);
 for (float[] vector : values) {
 for (float value : vector) {
 if (elementType == ElementType.FLOAT) {