@@ -245,7 +245,7 @@ public static class Builder extends FieldMapper.Builder {
245245 throw new MapperParsingException ("invalid element_type [" + o + "]; available types are " + namesToElementType .keySet ());
246246 }
247247 return elementType ;
248- }, m -> toType (m ).fieldType ().elementType , XContentBuilder ::field , Objects ::toString );
248+ }, m -> toType (m ).fieldType ().element . elementType () , XContentBuilder ::field , Objects ::toString );
249249 private final Parameter <Integer > dims ;
250250 private final Parameter <VectorSimilarity > similarity ;
251251
@@ -454,7 +454,13 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) {
454454 }
455455
456456 public enum ElementType {
457- BYTE , FLOAT , BIT ;
457+ BYTE ,
458+ FLOAT ,
459+ BIT ;
460+
461+ public static ElementType fromString (String name ) {
462+ return valueOf (name .toUpperCase (Locale .ROOT ));
463+ }
458464
459465 @ Override
460466 public String toString () {
@@ -475,15 +481,15 @@ public String toString() {
475481 ElementType .BIT
476482 );
477483
478- private static final Map <ElementType , Element > elements = Map .of (
479- ElementType .BYTE ,
480- BYTE_ELEMENT ,
481- ElementType .FLOAT ,
482- FLOAT_ELEMENT ,
483- ElementType .BIT ,
484- BIT_ELEMENT );
484+ public abstract static class Element {
485485
486- public static abstract class Element {
486+ public static Element getElement (ElementType elementType ) {
487+ return switch (elementType ) {
488+ case FLOAT -> FLOAT_ELEMENT ;
489+ case BYTE -> BYTE_ELEMENT ;
490+ case BIT -> BIT_ELEMENT ;
491+ };
492+ }
487493
488494 /**
489495 * Checks the input {@code vector} is one of the {@code possibleTypes},
@@ -495,7 +501,7 @@ public static ElementType checkValidVector(float[] vector, ElementType... possib
495501 // assume the types are in order of specificity
496502 StringBuilder [] errors = new StringBuilder [possibleTypes .length ];
497503 for (int i = 0 ; i < possibleTypes .length ; i ++) {
498- StringBuilder error = elements . get (possibleTypes [i ]).checkVectorErrors (vector );
504+ StringBuilder error = getElement (possibleTypes [i ]).checkVectorErrors (vector );
499505 if (error == null ) {
500506 // this one works - use it
501507 return possibleTypes [i ];
@@ -515,28 +521,28 @@ public static ElementType checkValidVector(float[] vector, ElementType... possib
515521 throw new IllegalArgumentException (FloatElement .appendErrorElements (message , vector ).toString ());
516522 }
517523
518- abstract ElementType elementType ();
524+ public abstract ElementType elementType ();
519525
520- abstract void writeValue (ByteBuffer byteBuffer , float value );
526+ public abstract void writeValue (ByteBuffer byteBuffer , float value );
521527
522- abstract void readAndWriteValue (ByteBuffer byteBuffer , XContentBuilder b ) throws IOException ;
528+ public abstract void readAndWriteValue (ByteBuffer byteBuffer , XContentBuilder b ) throws IOException ;
523529
524530 abstract IndexFieldData .Builder fielddataBuilder (DenseVectorFieldType denseVectorFieldType , FieldDataContext fieldDataContext );
525531
526532 abstract void parseKnnVectorAndIndex (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException ;
527533
528- abstract VectorData parseKnnVector (
534+ public abstract VectorData parseKnnVector (
529535 DocumentParserContext context ,
530536 int dims ,
531537 IntBooleanConsumer dimChecker ,
532538 VectorSimilarity similarity
533539 ) throws IOException ;
534540
535- abstract int getNumBytes (int dimensions );
541+ public abstract int getNumBytes (int dimensions );
536542
537- abstract ByteBuffer createByteBuffer (IndexVersion indexVersion , int numBytes );
543+ public abstract ByteBuffer createByteBuffer (IndexVersion indexVersion , int numBytes );
538544
539- void checkVectorBounds (float [] vector ) {
545+ public void checkVectorBounds (float [] vector ) {
540546 StringBuilder errors = checkVectorErrors (vector );
541547 if (errors != null ) {
542548 throw new IllegalArgumentException (FloatElement .appendErrorElements (errors , vector ).toString ());
@@ -553,15 +559,17 @@ abstract void checkVectorMagnitude(
553559 float squaredMagnitude
554560 );
555561
556- void checkDimensions (Integer dvDims , int qvDims ) {
562+ public abstract double computeSquaredMagnitude (VectorData vectorData );
563+
564+ public void checkDimensions (Integer dvDims , int qvDims ) {
557565 if (dvDims != null && dvDims != qvDims ) {
558566 throw new IllegalArgumentException (
559567 "The query vector has a different number of dimensions [" + qvDims + "] than the document vectors [" + dvDims + "]."
560568 );
561569 }
562570 }
563571
564- int parseDimensionCount (DocumentParserContext context ) throws IOException {
572+ public int parseDimensionCount (DocumentParserContext context ) throws IOException {
565573 int index = 0 ;
566574 for (Token token = context .parser ().nextToken (); token != Token .END_ARRAY ; token = context .parser ().nextToken ()) {
567575 index ++;
@@ -604,14 +612,12 @@ StringBuilder checkNanAndInfinite(float[] vector) {
604612
605613 return errorBuilder ;
606614 }
607-
608- public abstract double computeSquaredMagnitude (VectorData vectorData );
609615 }
610616
611617 private static class ByteElement extends Element {
612618
613619 @ Override
614- ElementType elementType () {
620+ public ElementType elementType () {
615621 return ElementType .BYTE ;
616622 }
617623
@@ -863,7 +869,7 @@ static UnaryOperator<StringBuilder> errorElementsAppender(byte[] vector) {
863869 private static class FloatElement extends Element {
864870
865871 @ Override
866- ElementType elementType () {
872+ public ElementType elementType () {
867873 return ElementType .FLOAT ;
868874 }
869875
@@ -1048,7 +1054,7 @@ static UnaryOperator<StringBuilder> errorElementsAppender(float[] vector) {
10481054 private static class BitElement extends ByteElement {
10491055
10501056 @ Override
1051- ElementType elementType () {
1057+ public ElementType elementType () {
10521058 return ElementType .BIT ;
10531059 }
10541060
@@ -2220,7 +2226,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
22202226 );
22212227
22222228 public static final class DenseVectorFieldType extends SimpleMappedFieldType {
2223- private final ElementType elementType ;
22242229 private final Element element ;
22252230 private final Integer dims ;
22262231 private final boolean indexed ;
@@ -2241,16 +2246,13 @@ public DenseVectorFieldType(
22412246 boolean isSyntheticSource
22422247 ) {
22432248 super (name , indexed , false , indexed == false , TextSearchInfo .NONE , meta );
2244- this .elementType = elementType ;
2249+ this .element = Element . getElement ( elementType ) ;
22452250 this .dims = dims ;
22462251 this .indexed = indexed ;
22472252 this .similarity = similarity ;
22482253 this .indexVersionCreated = indexVersionCreated ;
22492254 this .indexOptions = indexOptions ;
22502255 this .isSyntheticSource = isSyntheticSource ;
2251-
2252- this .element = elements .get (elementType );
2253- assert this .element != null ;
22542256 }
22552257
22562258 @ Override
@@ -2307,13 +2309,17 @@ public Query createExactKnnQuery(VectorData queryVector, Float vectorSimilarity)
23072309 "to perform knn search on field [" + name () + "], its mapping must have [index] set to [true]"
23082310 );
23092311 }
2310- Query knnQuery = switch (elementType ) {
2312+ Query knnQuery = switch (element . elementType () ) {
23112313 case BYTE -> createExactKnnByteQuery (queryVector .asByteVector ());
23122314 case FLOAT -> createExactKnnFloatQuery (queryVector .asFloatVector ());
23132315 case BIT -> createExactKnnBitQuery (queryVector .asByteVector ());
23142316 };
23152317 if (vectorSimilarity != null ) {
2316- knnQuery = new VectorSimilarityQuery (knnQuery , vectorSimilarity , similarity .score (vectorSimilarity , elementType , dims ));
2318+ knnQuery = new VectorSimilarityQuery (
2319+ knnQuery ,
2320+ vectorSimilarity ,
2321+ similarity .score (vectorSimilarity , element .elementType (), dims )
2322+ );
23172323 }
23182324 return knnQuery ;
23192325 }
@@ -2323,15 +2329,15 @@ public boolean isNormalized() {
23232329 }
23242330
23252331 private Query createExactKnnBitQuery (byte [] queryVector ) {
2326- elements . get ( elementType ) .checkDimensions (dims , queryVector .length );
2332+ element .checkDimensions (dims , queryVector .length );
23272333 return new DenseVectorQuery .Bytes (queryVector , name ());
23282334 }
23292335
23302336 private Query createExactKnnByteQuery (byte [] queryVector ) {
2331- elements . get ( elementType ) .checkDimensions (dims , queryVector .length );
2337+ element .checkDimensions (dims , queryVector .length );
23322338 if (similarity == VectorSimilarity .DOT_PRODUCT || similarity == VectorSimilarity .COSINE ) {
23332339 float squaredMagnitude = VectorUtil .dotProduct (queryVector , queryVector );
2334- elements . get ( elementType ) .checkVectorMagnitude (similarity , ByteElement .errorElementsAppender (queryVector ), squaredMagnitude );
2340+ element .checkVectorMagnitude (similarity , ByteElement .errorElementsAppender (queryVector ), squaredMagnitude );
23352341 }
23362342 return new DenseVectorQuery .Bytes (queryVector , name ());
23372343 }
@@ -2449,7 +2455,7 @@ private Query createKnnBitQuery(
24492455 knnQuery = new VectorSimilarityQuery (
24502456 knnQuery ,
24512457 similarityThreshold ,
2452- similarity .score (similarityThreshold , elementType , dims )
2458+ similarity .score (similarityThreshold , element . elementType () , dims )
24532459 );
24542460 }
24552461 return knnQuery ;
@@ -2493,7 +2499,7 @@ private Query createKnnByteQuery(
24932499 knnQuery = new VectorSimilarityQuery (
24942500 knnQuery ,
24952501 similarityThreshold ,
2496- similarity .score (similarityThreshold , elementType , dims )
2502+ similarity .score (similarityThreshold , element . elementType () , dims )
24972503 );
24982504 }
24992505 return knnQuery ;
@@ -2609,7 +2615,7 @@ private Query createKnnFloatQuery(
26092615 knnQuery = new VectorSimilarityQuery (
26102616 knnQuery ,
26112617 similarityThreshold ,
2612- similarity .score (similarityThreshold , elementType , dims )
2618+ similarity .score (similarityThreshold , element . elementType () , dims )
26132619 );
26142620 }
26152621 return knnQuery ;
@@ -2624,7 +2630,7 @@ int getVectorDimensions() {
26242630 }
26252631
26262632 public ElementType getElementType () {
2627- return elementType ;
2633+ return element . elementType () ;
26282634 }
26292635
26302636 public DenseVectorIndexOptions getIndexOptions () {
@@ -2633,7 +2639,7 @@ public DenseVectorIndexOptions getIndexOptions() {
26332639
26342640 @ Override
26352641 public BlockLoader blockLoader (MappedFieldType .BlockLoaderContext blContext ) {
2636- if (elementType == ElementType .BIT ) {
2642+ if (element . elementType () == ElementType .BIT ) {
26372643 // Just float and byte dense vector support for now
26382644 return null ;
26392645 }
@@ -2648,7 +2654,7 @@ public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) {
26482654 }
26492655
26502656 if (hasDocValues () && (blContext .fieldExtractPreference () != FieldExtractPreference .STORED || isSyntheticSource )) {
2651- return new BlockDocValuesReader .DenseVectorFromBinaryBlockLoader (name (), dims , indexVersionCreated , elementType );
2657+ return new BlockDocValuesReader .DenseVectorFromBinaryBlockLoader (name (), dims , indexVersionCreated , element . elementType () );
26522658 }
26532659
26542660 BlockSourceReader .LeafIteratorLookup lookup = BlockSourceReader .lookupMatchingAll ();
@@ -2838,9 +2844,9 @@ private static DenseVectorIndexOptions parseIndexOptions(String fieldName, Objec
28382844 public KnnVectorsFormat getKnnVectorsFormatForField (KnnVectorsFormat defaultFormat ) {
28392845 final KnnVectorsFormat format ;
28402846 if (indexOptions == null ) {
2841- format = fieldType ().elementType == ElementType .BIT ? new ES815HnswBitVectorsFormat () : defaultFormat ;
2847+ format = fieldType ().element . elementType () == ElementType .BIT ? new ES815HnswBitVectorsFormat () : defaultFormat ;
28422848 } else {
2843- format = indexOptions .getVectorsFormat (fieldType ().elementType );
2849+ format = indexOptions .getVectorsFormat (fieldType ().element . elementType () );
28442850 }
28452851 // It's legal to reuse the same format name as this is the same on-disk format.
28462852 return new KnnVectorsFormat (format .getName ()) {
@@ -3047,7 +3053,7 @@ public void write(XContentBuilder b) throws IOException {
30473053 if (indexCreatedVersion .onOrAfter (LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION )) {
30483054 byteBuffer .order (ByteOrder .LITTLE_ENDIAN );
30493055 }
3050- int dims = fieldType ().elementType == ElementType .BIT ? fieldType ().dims / Byte .SIZE : fieldType ().dims ;
3056+ int dims = fieldType ().element . elementType () == ElementType .BIT ? fieldType ().dims / Byte .SIZE : fieldType ().dims ;
30513057 for (int dim = 0 ; dim < dims ; dim ++) {
30523058 fieldType ().element .readAndWriteValue (byteBuffer , b );
30533059 }
0 commit comments