Skip to content

Commit ccd0a23

Browse files
committed
Propagate changes
1 parent df30d63 commit ccd0a23

File tree

11 files changed

+103
-79
lines changed

11 files changed

+103
-79
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 50 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ public static class Builder extends FieldMapper.Builder {
245245
throw new MapperParsingException("invalid element_type [" + o + "]; available types are " + namesToElementType.keySet());
246246
}
247247
return elementType;
248-
}, m -> toType(m).fieldType().elementType, XContentBuilder::field, Objects::toString);
248+
}, m -> toType(m).fieldType().element.elementType(), XContentBuilder::field, Objects::toString);
249249
private final Parameter<Integer> dims;
250250
private final Parameter<VectorSimilarity> similarity;
251251

@@ -454,7 +454,13 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) {
454454
}
455455

456456
public enum ElementType {
457-
BYTE, FLOAT, BIT;
457+
BYTE,
458+
FLOAT,
459+
BIT;
460+
461+
public static ElementType fromString(String name) {
462+
return valueOf(name.toUpperCase(Locale.ROOT));
463+
}
458464

459465
@Override
460466
public String toString() {
@@ -475,15 +481,15 @@ public String toString() {
475481
ElementType.BIT
476482
);
477483

478-
private static final Map<ElementType, Element> elements = Map.of(
479-
ElementType.BYTE,
480-
BYTE_ELEMENT,
481-
ElementType.FLOAT,
482-
FLOAT_ELEMENT,
483-
ElementType.BIT,
484-
BIT_ELEMENT);
484+
public abstract static class Element {
485485

486-
public static abstract class Element {
486+
public static Element getElement(ElementType elementType) {
487+
return switch (elementType) {
488+
case FLOAT -> FLOAT_ELEMENT;
489+
case BYTE -> BYTE_ELEMENT;
490+
case BIT -> BIT_ELEMENT;
491+
};
492+
}
487493

488494
/**
489495
* Checks the input {@code vector} is one of the {@code possibleTypes},
@@ -495,7 +501,7 @@ public static ElementType checkValidVector(float[] vector, ElementType... possib
495501
// assume the types are in order of specificity
496502
StringBuilder[] errors = new StringBuilder[possibleTypes.length];
497503
for (int i = 0; i < possibleTypes.length; i++) {
498-
StringBuilder error = elements.get(possibleTypes[i]).checkVectorErrors(vector);
504+
StringBuilder error = getElement(possibleTypes[i]).checkVectorErrors(vector);
499505
if (error == null) {
500506
// this one works - use it
501507
return possibleTypes[i];
@@ -515,28 +521,28 @@ public static ElementType checkValidVector(float[] vector, ElementType... possib
515521
throw new IllegalArgumentException(FloatElement.appendErrorElements(message, vector).toString());
516522
}
517523

518-
abstract ElementType elementType();
524+
public abstract ElementType elementType();
519525

520-
abstract void writeValue(ByteBuffer byteBuffer, float value);
526+
public abstract void writeValue(ByteBuffer byteBuffer, float value);
521527

522-
abstract void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException;
528+
public abstract void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException;
523529

524530
abstract IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldType, FieldDataContext fieldDataContext);
525531

526532
abstract void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException;
527533

528-
abstract VectorData parseKnnVector(
534+
public abstract VectorData parseKnnVector(
529535
DocumentParserContext context,
530536
int dims,
531537
IntBooleanConsumer dimChecker,
532538
VectorSimilarity similarity
533539
) throws IOException;
534540

535-
abstract int getNumBytes(int dimensions);
541+
public abstract int getNumBytes(int dimensions);
536542

537-
abstract ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes);
543+
public abstract ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes);
538544

539-
void checkVectorBounds(float[] vector) {
545+
public void checkVectorBounds(float[] vector) {
540546
StringBuilder errors = checkVectorErrors(vector);
541547
if (errors != null) {
542548
throw new IllegalArgumentException(FloatElement.appendErrorElements(errors, vector).toString());
@@ -553,15 +559,17 @@ abstract void checkVectorMagnitude(
553559
float squaredMagnitude
554560
);
555561

556-
void checkDimensions(Integer dvDims, int qvDims) {
562+
public abstract double computeSquaredMagnitude(VectorData vectorData);
563+
564+
public void checkDimensions(Integer dvDims, int qvDims) {
557565
if (dvDims != null && dvDims != qvDims) {
558566
throw new IllegalArgumentException(
559567
"The query vector has a different number of dimensions [" + qvDims + "] than the document vectors [" + dvDims + "]."
560568
);
561569
}
562570
}
563571

564-
int parseDimensionCount(DocumentParserContext context) throws IOException {
572+
public int parseDimensionCount(DocumentParserContext context) throws IOException {
565573
int index = 0;
566574
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
567575
index++;
@@ -604,14 +612,12 @@ StringBuilder checkNanAndInfinite(float[] vector) {
604612

605613
return errorBuilder;
606614
}
607-
608-
public abstract double computeSquaredMagnitude(VectorData vectorData);
609615
}
610616

611617
private static class ByteElement extends Element {
612618

613619
@Override
614-
ElementType elementType() {
620+
public ElementType elementType() {
615621
return ElementType.BYTE;
616622
}
617623

@@ -863,7 +869,7 @@ static UnaryOperator<StringBuilder> errorElementsAppender(byte[] vector) {
863869
private static class FloatElement extends Element {
864870

865871
@Override
866-
ElementType elementType() {
872+
public ElementType elementType() {
867873
return ElementType.FLOAT;
868874
}
869875

@@ -1048,7 +1054,7 @@ static UnaryOperator<StringBuilder> errorElementsAppender(float[] vector) {
10481054
private static class BitElement extends ByteElement {
10491055

10501056
@Override
1051-
ElementType elementType() {
1057+
public ElementType elementType() {
10521058
return ElementType.BIT;
10531059
}
10541060

@@ -2220,7 +2226,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
22202226
);
22212227

22222228
public static final class DenseVectorFieldType extends SimpleMappedFieldType {
2223-
private final ElementType elementType;
22242229
private final Element element;
22252230
private final Integer dims;
22262231
private final boolean indexed;
@@ -2241,16 +2246,13 @@ public DenseVectorFieldType(
22412246
boolean isSyntheticSource
22422247
) {
22432248
super(name, indexed, false, indexed == false, TextSearchInfo.NONE, meta);
2244-
this.elementType = elementType;
2249+
this.element = Element.getElement(elementType);
22452250
this.dims = dims;
22462251
this.indexed = indexed;
22472252
this.similarity = similarity;
22482253
this.indexVersionCreated = indexVersionCreated;
22492254
this.indexOptions = indexOptions;
22502255
this.isSyntheticSource = isSyntheticSource;
2251-
2252-
this.element = elements.get(elementType);
2253-
assert this.element != null;
22542256
}
22552257

22562258
@Override
@@ -2307,13 +2309,17 @@ public Query createExactKnnQuery(VectorData queryVector, Float vectorSimilarity)
23072309
"to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]"
23082310
);
23092311
}
2310-
Query knnQuery = switch (elementType) {
2312+
Query knnQuery = switch (element.elementType()) {
23112313
case BYTE -> createExactKnnByteQuery(queryVector.asByteVector());
23122314
case FLOAT -> createExactKnnFloatQuery(queryVector.asFloatVector());
23132315
case BIT -> createExactKnnBitQuery(queryVector.asByteVector());
23142316
};
23152317
if (vectorSimilarity != null) {
2316-
knnQuery = new VectorSimilarityQuery(knnQuery, vectorSimilarity, similarity.score(vectorSimilarity, elementType, dims));
2318+
knnQuery = new VectorSimilarityQuery(
2319+
knnQuery,
2320+
vectorSimilarity,
2321+
similarity.score(vectorSimilarity, element.elementType(), dims)
2322+
);
23172323
}
23182324
return knnQuery;
23192325
}
@@ -2323,15 +2329,15 @@ public boolean isNormalized() {
23232329
}
23242330

23252331
private Query createExactKnnBitQuery(byte[] queryVector) {
2326-
elements.get(elementType).checkDimensions(dims, queryVector.length);
2332+
element.checkDimensions(dims, queryVector.length);
23272333
return new DenseVectorQuery.Bytes(queryVector, name());
23282334
}
23292335

23302336
private Query createExactKnnByteQuery(byte[] queryVector) {
2331-
elements.get(elementType).checkDimensions(dims, queryVector.length);
2337+
element.checkDimensions(dims, queryVector.length);
23322338
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
23332339
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
2334-
elements.get(elementType).checkVectorMagnitude(similarity, ByteElement.errorElementsAppender(queryVector), squaredMagnitude);
2340+
element.checkVectorMagnitude(similarity, ByteElement.errorElementsAppender(queryVector), squaredMagnitude);
23352341
}
23362342
return new DenseVectorQuery.Bytes(queryVector, name());
23372343
}
@@ -2449,7 +2455,7 @@ private Query createKnnBitQuery(
24492455
knnQuery = new VectorSimilarityQuery(
24502456
knnQuery,
24512457
similarityThreshold,
2452-
similarity.score(similarityThreshold, elementType, dims)
2458+
similarity.score(similarityThreshold, element.elementType(), dims)
24532459
);
24542460
}
24552461
return knnQuery;
@@ -2493,7 +2499,7 @@ private Query createKnnByteQuery(
24932499
knnQuery = new VectorSimilarityQuery(
24942500
knnQuery,
24952501
similarityThreshold,
2496-
similarity.score(similarityThreshold, elementType, dims)
2502+
similarity.score(similarityThreshold, element.elementType(), dims)
24972503
);
24982504
}
24992505
return knnQuery;
@@ -2609,7 +2615,7 @@ private Query createKnnFloatQuery(
26092615
knnQuery = new VectorSimilarityQuery(
26102616
knnQuery,
26112617
similarityThreshold,
2612-
similarity.score(similarityThreshold, elementType, dims)
2618+
similarity.score(similarityThreshold, element.elementType(), dims)
26132619
);
26142620
}
26152621
return knnQuery;
@@ -2624,7 +2630,7 @@ int getVectorDimensions() {
26242630
}
26252631

26262632
public ElementType getElementType() {
2627-
return elementType;
2633+
return element.elementType();
26282634
}
26292635

26302636
public DenseVectorIndexOptions getIndexOptions() {
@@ -2633,7 +2639,7 @@ public DenseVectorIndexOptions getIndexOptions() {
26332639

26342640
@Override
26352641
public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) {
2636-
if (elementType == ElementType.BIT) {
2642+
if (element.elementType() == ElementType.BIT) {
26372643
// Just float and byte dense vector support for now
26382644
return null;
26392645
}
@@ -2648,7 +2654,7 @@ public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) {
26482654
}
26492655

26502656
if (hasDocValues() && (blContext.fieldExtractPreference() != FieldExtractPreference.STORED || isSyntheticSource)) {
2651-
return new BlockDocValuesReader.DenseVectorFromBinaryBlockLoader(name(), dims, indexVersionCreated, elementType);
2657+
return new BlockDocValuesReader.DenseVectorFromBinaryBlockLoader(name(), dims, indexVersionCreated, element.elementType());
26522658
}
26532659

26542660
BlockSourceReader.LeafIteratorLookup lookup = BlockSourceReader.lookupMatchingAll();
@@ -2838,9 +2844,9 @@ private static DenseVectorIndexOptions parseIndexOptions(String fieldName, Objec
28382844
public KnnVectorsFormat getKnnVectorsFormatForField(KnnVectorsFormat defaultFormat) {
28392845
final KnnVectorsFormat format;
28402846
if (indexOptions == null) {
2841-
format = fieldType().elementType == ElementType.BIT ? new ES815HnswBitVectorsFormat() : defaultFormat;
2847+
format = fieldType().element.elementType() == ElementType.BIT ? new ES815HnswBitVectorsFormat() : defaultFormat;
28422848
} else {
2843-
format = indexOptions.getVectorsFormat(fieldType().elementType);
2849+
format = indexOptions.getVectorsFormat(fieldType().element.elementType());
28442850
}
28452851
// It's legal to reuse the same format name as this is the same on-disk format.
28462852
return new KnnVectorsFormat(format.getName()) {
@@ -3047,7 +3053,7 @@ public void write(XContentBuilder b) throws IOException {
30473053
if (indexCreatedVersion.onOrAfter(LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION)) {
30483054
byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
30493055
}
3050-
int dims = fieldType().elementType == ElementType.BIT ? fieldType().dims / Byte.SIZE : fieldType().dims;
3056+
int dims = fieldType().element.elementType() == ElementType.BIT ? fieldType().dims / Byte.SIZE : fieldType().dims;
30513057
for (int dim = 0; dim < dims; dim++) {
30523058
fieldType().element.readAndWriteValue(byteBuffer, b);
30533059
}

server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import org.elasticsearch.ExceptionsHelper;
1313
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
14+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.Element;
1415
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
1516
import org.elasticsearch.script.field.vectors.DenseVector;
1617
import org.elasticsearch.script.field.vectors.DenseVectorDocValuesField;
@@ -66,7 +67,7 @@ public ByteDenseVectorFunction(
6667
ElementType... allowedTypes
6768
) {
6869
super(scoreScript, field);
69-
field.getElementType().checkDimensions(field.get().getDims(), queryVector.size());
70+
field.getElement().checkDimensions(field.get().getDims(), queryVector.size());
7071
float[] floatValues = new float[queryVector.size()];
7172
double queryMagnitude = 0;
7273
for (int i = 0; i < queryVector.size(); i++) {
@@ -76,7 +77,7 @@ public ByteDenseVectorFunction(
7677
}
7778
queryMagnitude = Math.sqrt(queryMagnitude);
7879

79-
switch (ElementType.checkValidVector(floatValues, allowedTypes)) {
80+
switch (Element.checkValidVector(floatValues, allowedTypes)) {
8081
case FLOAT:
8182
byteQueryVector = null;
8283
floatQueryVector = floatValues;
@@ -149,7 +150,7 @@ public FloatDenseVectorFunction(
149150
queryMagnitude += value * value;
150151
}
151152
queryMagnitude = Math.sqrt(queryMagnitude);
152-
field.getElementType().checkVectorBounds(this.queryVector);
153+
field.getElement().checkVectorBounds(this.queryVector);
153154

154155
if (normalizeQuery) {
155156
for (int dim = 0; dim < this.queryVector.length; dim++) {

server/src/main/java/org/elasticsearch/script/field/vectors/DenseVectorDocValuesField.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import java.util.Iterator;
1818

19+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.Element;
1920
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
2021

2122
public abstract class DenseVectorDocValuesField extends AbstractScriptFieldFactory<DenseVector>
@@ -40,6 +41,10 @@ public ElementType getElementType() {
4041
return elementType;
4142
}
4243

44+
public Element getElement() {
45+
return Element.getElement(elementType);
46+
}
47+
4348
@Override
4449
public int size() {
4550
return isEmpty() ? 0 : 1;

server/src/main/java/org/elasticsearch/script/field/vectors/RankVectorsDocValuesField.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,16 @@
99

1010
package org.elasticsearch.script.field.vectors;
1111

12+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
13+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.Element;
14+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
1215
import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues;
1316
import org.elasticsearch.script.field.AbstractScriptFieldFactory;
1417
import org.elasticsearch.script.field.DocValuesScriptFieldFactory;
1518
import org.elasticsearch.script.field.Field;
1619

1720
import java.util.Iterator;
1821

19-
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
20-
2122
public abstract class RankVectorsDocValuesField extends AbstractScriptFieldFactory<RankVectors>
2223
implements
2324
Field<RankVectors>,
@@ -40,6 +41,10 @@ public ElementType getElementType() {
4041
return elementType;
4142
}
4243

44+
public Element getElement() {
45+
return DenseVectorFieldMapper.Element.getElement(elementType);
46+
}
47+
4348
/**
4449
* Get the DenseVector for a document if one exists, DenseVector.EMPTY otherwise
4550
*/

server/src/main/java/org/elasticsearch/search/vectors/VectorData.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public byte[] asByteVector() {
5656
if (byteVector != null) {
5757
return byteVector;
5858
}
59-
DenseVectorFieldMapper.ElementType.BYTE.checkVectorBounds(floatVector);
59+
DenseVectorFieldMapper.BYTE_ELEMENT.checkVectorBounds(floatVector);
6060
byte[] vec = new byte[floatVector.length];
6161
for (int i = 0; i < floatVector.length; i++) {
6262
vec[i] = (byte) floatVector[i];

0 commit comments

Comments
 (0)