@@ -416,13 +416,18 @@ public double computeSquaredMagnitude(VectorData vectorData) {
416416 return VectorUtil .dotProduct (vectorData .asByteVector (), vectorData .asByteVector ());
417417 }
418418
419- private VectorData parseVectorArray (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
419+ private VectorData parseVectorArray (
420+ DocumentParserContext context ,
421+ int dims ,
422+ IntBooleanConsumer dimChecker ,
423+ VectorSimilarity similarity
424+ ) throws IOException {
420425 int index = 0 ;
421- byte [] vector = new byte [fieldMapper . fieldType (). dims ];
426+ byte [] vector = new byte [dims ];
422427 float squaredMagnitude = 0 ;
423428 for (XContentParser .Token token = context .parser ().nextToken (); token != Token .END_ARRAY ; token = context .parser ()
424429 .nextToken ()) {
425- fieldMapper . checkDimensionExceeded (index , context );
430+ dimChecker . accept (index , false );
426431 ensureExpectedToken (Token .VALUE_NUMBER , token , context .parser ());
427432 final int value ;
428433 if (context .parser ().numberType () != XContentParser .NumberType .INT ) {
@@ -460,30 +465,31 @@ private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFi
460465 vector [index ++] = (byte ) value ;
461466 squaredMagnitude += value * value ;
462467 }
463- fieldMapper . checkDimensionMatches (index , context );
464- checkVectorMagnitude (fieldMapper . fieldType (). similarity , errorByteElementsAppender (vector ), squaredMagnitude );
468+ dimChecker . accept (index , true );
469+ checkVectorMagnitude (similarity , errorByteElementsAppender (vector ), squaredMagnitude );
465470 return VectorData .fromBytes (vector );
466471 }
467472
468- private VectorData parseHexEncodedVector (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
473+ private VectorData parseHexEncodedVector (
474+ DocumentParserContext context ,
475+ IntBooleanConsumer dimChecker ,
476+ VectorSimilarity similarity
477+ ) throws IOException {
469478 byte [] decodedVector = HexFormat .of ().parseHex (context .parser ().text ());
470- fieldMapper . checkDimensionMatches (decodedVector .length , context );
479+ dimChecker . accept (decodedVector .length , true );
471480 VectorData vectorData = VectorData .fromBytes (decodedVector );
472481 double squaredMagnitude = computeSquaredMagnitude (vectorData );
473- checkVectorMagnitude (
474- fieldMapper .fieldType ().similarity ,
475- errorByteElementsAppender (decodedVector ),
476- (float ) squaredMagnitude
477- );
482+ checkVectorMagnitude (similarity , errorByteElementsAppender (decodedVector ), (float ) squaredMagnitude );
478483 return vectorData ;
479484 }
480485
481486 @ Override
482- VectorData parseKnnVector (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
487+ VectorData parseKnnVector (DocumentParserContext context , int dims , IntBooleanConsumer dimChecker , VectorSimilarity similarity )
488+ throws IOException {
483489 XContentParser .Token token = context .parser ().currentToken ();
484490 return switch (token ) {
485- case START_ARRAY -> parseVectorArray (context , fieldMapper );
486- case VALUE_STRING -> parseHexEncodedVector (context , fieldMapper );
491+ case START_ARRAY -> parseVectorArray (context , dims , dimChecker , similarity );
492+ case VALUE_STRING -> parseHexEncodedVector (context , dimChecker , similarity );
487493 default -> throw new ParsingException (
488494 context .parser ().getTokenLocation (),
489495 format ("Unsupported type [%s] for provided value [%s]" , token , context .parser ().text ())
@@ -493,7 +499,13 @@ VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper
493499
494500 @ Override
495501 public void parseKnnVectorAndIndex (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
496- VectorData vectorData = parseKnnVector (context , fieldMapper );
502+ VectorData vectorData = parseKnnVector (context , fieldMapper .fieldType ().dims , (i , end ) -> {
503+ if (end ) {
504+ fieldMapper .checkDimensionMatches (i , context );
505+ } else {
506+ fieldMapper .checkDimensionExceeded (i , context );
507+ }
508+ }, fieldMapper .fieldType ().similarity );
497509 Field field = createKnnVectorField (
498510 fieldMapper .fieldType ().name (),
499511 vectorData .asByteVector (),
@@ -677,21 +689,22 @@ && isNotUnitVector(squaredMagnitude)) {
677689 }
678690
679691 @ Override
680- VectorData parseKnnVector (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
692+ VectorData parseKnnVector (DocumentParserContext context , int dims , IntBooleanConsumer dimChecker , VectorSimilarity similarity )
693+ throws IOException {
681694 int index = 0 ;
682695 float squaredMagnitude = 0 ;
683- float [] vector = new float [fieldMapper . fieldType (). dims ];
696+ float [] vector = new float [dims ];
684697 for (Token token = context .parser ().nextToken (); token != Token .END_ARRAY ; token = context .parser ().nextToken ()) {
685- fieldMapper . checkDimensionExceeded (index , context );
698+ dimChecker . accept (index , false );
686699 ensureExpectedToken (Token .VALUE_NUMBER , token , context .parser ());
687700 float value = context .parser ().floatValue (true );
688701 vector [index ] = value ;
689702 squaredMagnitude += value * value ;
690703 index ++;
691704 }
692- fieldMapper . checkDimensionMatches (index , context );
705+ dimChecker . accept (index , true );
693706 checkVectorBounds (vector );
694- checkVectorMagnitude (fieldMapper . fieldType (). similarity , errorFloatElementsAppender (vector ), squaredMagnitude );
707+ checkVectorMagnitude (similarity , errorFloatElementsAppender (vector ), squaredMagnitude );
695708 return VectorData .fromFloats (vector );
696709 }
697710
@@ -816,12 +829,17 @@ public double computeSquaredMagnitude(VectorData vectorData) {
816829 return count ;
817830 }
818831
819- private VectorData parseVectorArray (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
832+ private VectorData parseVectorArray (
833+ DocumentParserContext context ,
834+ int dims ,
835+ IntBooleanConsumer dimChecker ,
836+ VectorSimilarity similarity
837+ ) throws IOException {
820838 int index = 0 ;
821- byte [] vector = new byte [fieldMapper . fieldType (). dims / Byte .SIZE ];
839+ byte [] vector = new byte [dims / Byte .SIZE ];
822840 for (XContentParser .Token token = context .parser ().nextToken (); token != Token .END_ARRAY ; token = context .parser ()
823841 .nextToken ()) {
824- fieldMapper . checkDimensionExceeded (index , context );
842+ dimChecker . accept (index * Byte . SIZE , false );
825843 ensureExpectedToken (Token .VALUE_NUMBER , token , context .parser ());
826844 final int value ;
827845 if (context .parser ().numberType () != XContentParser .NumberType .INT ) {
@@ -856,35 +874,25 @@ private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFi
856874 + "];"
857875 );
858876 }
859- if (index >= vector .length ) {
860- throw new IllegalArgumentException (
861- "The number of dimensions for field ["
862- + fieldMapper .fieldType ().name ()
863- + "] should be ["
864- + fieldMapper .fieldType ().dims
865- + "] but found ["
866- + (index + 1 ) * Byte .SIZE
867- + "]"
868- );
869- }
870877 vector [index ++] = (byte ) value ;
871878 }
872- fieldMapper . checkDimensionMatches (index * Byte .SIZE , context );
879+ dimChecker . accept (index * Byte .SIZE , true );
873880 return VectorData .fromBytes (vector );
874881 }
875882
876- private VectorData parseHexEncodedVector (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
883+ private VectorData parseHexEncodedVector (DocumentParserContext context , IntBooleanConsumer dimChecker ) throws IOException {
877884 byte [] decodedVector = HexFormat .of ().parseHex (context .parser ().text ());
878- fieldMapper . checkDimensionMatches (decodedVector .length * Byte .SIZE , context );
885+ dimChecker . accept (decodedVector .length * Byte .SIZE , true );
879886 return VectorData .fromBytes (decodedVector );
880887 }
881888
882889 @ Override
883- VectorData parseKnnVector (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
890+ VectorData parseKnnVector (DocumentParserContext context , int dims , IntBooleanConsumer dimChecker , VectorSimilarity similarity )
891+ throws IOException {
884892 XContentParser .Token token = context .parser ().currentToken ();
885893 return switch (token ) {
886- case START_ARRAY -> parseVectorArray (context , fieldMapper );
887- case VALUE_STRING -> parseHexEncodedVector (context , fieldMapper );
894+ case START_ARRAY -> parseVectorArray (context , dims , dimChecker , similarity );
895+ case VALUE_STRING -> parseHexEncodedVector (context , dimChecker );
888896 default -> throw new ParsingException (
889897 context .parser ().getTokenLocation (),
890898 format ("Unsupported type [%s] for provided value [%s]" , token , context .parser ().text ())
@@ -894,7 +902,13 @@ VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper
894902
895903 @ Override
896904 public void parseKnnVectorAndIndex (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
897- VectorData vectorData = parseKnnVector (context , fieldMapper );
905+ VectorData vectorData = parseKnnVector (context , fieldMapper .fieldType ().dims , (i , end ) -> {
906+ if (end ) {
907+ fieldMapper .checkDimensionMatches (i , context );
908+ } else {
909+ fieldMapper .checkDimensionExceeded (i , context );
910+ }
911+ }, fieldMapper .fieldType ().similarity );
898912 Field field = createKnnVectorField (
899913 fieldMapper .fieldType ().name (),
900914 vectorData .asByteVector (),
@@ -958,7 +972,12 @@ public void checkDimensions(Integer dvDims, int qvDims) {
958972
959973 abstract void parseKnnVectorAndIndex (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException ;
960974
961- abstract VectorData parseKnnVector (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException ;
975+ abstract VectorData parseKnnVector (
976+ DocumentParserContext context ,
977+ int dims ,
978+ IntBooleanConsumer dimChecker ,
979+ VectorSimilarity similarity
980+ ) throws IOException ;
962981
963982 abstract int getNumBytes (int dimensions );
964983
@@ -2180,7 +2199,13 @@ private void parseBinaryDocValuesVectorAndIndex(DocumentParserContext context) t
21802199 : elementType .getNumBytes (dims );
21812200
21822201 ByteBuffer byteBuffer = elementType .createByteBuffer (indexCreatedVersion , numBytes );
2183- VectorData vectorData = elementType .parseKnnVector (context , this );
2202+ VectorData vectorData = elementType .parseKnnVector (context , dims , (i , b ) -> {
2203+ if (b ) {
2204+ checkDimensionMatches (i , context );
2205+ } else {
2206+ checkDimensionExceeded (i , context );
2207+ }
2208+ }, fieldType ().similarity );
21842209 vectorData .addToBuffer (byteBuffer );
21852210 if (indexCreatedVersion .onOrAfter (MAGNITUDE_STORED_INDEX_VERSION )) {
21862211 // encode vector magnitude at the end
@@ -2433,4 +2458,11 @@ public String fieldName() {
24332458 return fullPath ();
24342459 }
24352460 }
2461+
2462+ /**
2463+ * @FunctionalInterface for a function that takes a int and boolean
2464+ */
2465+ interface IntBooleanConsumer {
2466+ void accept (int value , boolean isComplete );
2467+ }
24362468}
0 commit comments