@@ -415,13 +415,18 @@ public double computeSquaredMagnitude(VectorData vectorData) {
415415 return VectorUtil .dotProduct (vectorData .asByteVector (), vectorData .asByteVector ());
416416 }
417417
418- private VectorData parseVectorArray (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
418+ private VectorData parseVectorArray (
419+ DocumentParserContext context ,
420+ int dims ,
421+ IntBooleanConsumer dimChecker ,
422+ VectorSimilarity similarity
423+ ) throws IOException {
419424 int index = 0 ;
420- byte [] vector = new byte [fieldMapper . fieldType (). dims ];
425+ byte [] vector = new byte [dims ];
421426 float squaredMagnitude = 0 ;
422427 for (XContentParser .Token token = context .parser ().nextToken (); token != Token .END_ARRAY ; token = context .parser ()
423428 .nextToken ()) {
424- fieldMapper . checkDimensionExceeded (index , context );
429+ dimChecker . accept (index , false );
425430 ensureExpectedToken (Token .VALUE_NUMBER , token , context .parser ());
426431 final int value ;
427432 if (context .parser ().numberType () != XContentParser .NumberType .INT ) {
@@ -459,30 +464,31 @@ private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFi
459464 vector [index ++] = (byte ) value ;
460465 squaredMagnitude += value * value ;
461466 }
462- fieldMapper . checkDimensionMatches (index , context );
463- checkVectorMagnitude (fieldMapper . fieldType (). similarity , errorByteElementsAppender (vector ), squaredMagnitude );
467+ dimChecker . accept (index , true );
468+ checkVectorMagnitude (similarity , errorByteElementsAppender (vector ), squaredMagnitude );
464469 return VectorData .fromBytes (vector );
465470 }
466471
467- private VectorData parseHexEncodedVector (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
472+ private VectorData parseHexEncodedVector (
473+ DocumentParserContext context ,
474+ IntBooleanConsumer dimChecker ,
475+ VectorSimilarity similarity
476+ ) throws IOException {
468477 byte [] decodedVector = HexFormat .of ().parseHex (context .parser ().text ());
469- fieldMapper . checkDimensionMatches (decodedVector .length , context );
478+ dimChecker . accept (decodedVector .length , true );
470479 VectorData vectorData = VectorData .fromBytes (decodedVector );
471480 double squaredMagnitude = computeSquaredMagnitude (vectorData );
472- checkVectorMagnitude (
473- fieldMapper .fieldType ().similarity ,
474- errorByteElementsAppender (decodedVector ),
475- (float ) squaredMagnitude
476- );
481+ checkVectorMagnitude (similarity , errorByteElementsAppender (decodedVector ), (float ) squaredMagnitude );
477482 return vectorData ;
478483 }
479484
480485 @ Override
481- VectorData parseKnnVector (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
486+ VectorData parseKnnVector (DocumentParserContext context , int dims , IntBooleanConsumer dimChecker , VectorSimilarity similarity )
487+ throws IOException {
482488 XContentParser .Token token = context .parser ().currentToken ();
483489 return switch (token ) {
484- case START_ARRAY -> parseVectorArray (context , fieldMapper );
485- case VALUE_STRING -> parseHexEncodedVector (context , fieldMapper );
490+ case START_ARRAY -> parseVectorArray (context , dims , dimChecker , similarity );
491+ case VALUE_STRING -> parseHexEncodedVector (context , dimChecker , similarity );
486492 default -> throw new ParsingException (
487493 context .parser ().getTokenLocation (),
488494 format ("Unsupported type [%s] for provided value [%s]" , token , context .parser ().text ())
@@ -492,7 +498,13 @@ VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper
492498
493499 @ Override
494500 public void parseKnnVectorAndIndex (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
495- VectorData vectorData = parseKnnVector (context , fieldMapper );
501+ VectorData vectorData = parseKnnVector (context , fieldMapper .fieldType ().dims , (i , end ) -> {
502+ if (end ) {
503+ fieldMapper .checkDimensionMatches (i , context );
504+ } else {
505+ fieldMapper .checkDimensionExceeded (i , context );
506+ }
507+ }, fieldMapper .fieldType ().similarity );
496508 Field field = createKnnVectorField (
497509 fieldMapper .fieldType ().name (),
498510 vectorData .asByteVector (),
@@ -676,21 +688,22 @@ && isNotUnitVector(squaredMagnitude)) {
676688 }
677689
678690 @ Override
679- VectorData parseKnnVector (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
691+ VectorData parseKnnVector (DocumentParserContext context , int dims , IntBooleanConsumer dimChecker , VectorSimilarity similarity )
692+ throws IOException {
680693 int index = 0 ;
681694 float squaredMagnitude = 0 ;
682- float [] vector = new float [fieldMapper . fieldType (). dims ];
695+ float [] vector = new float [dims ];
683696 for (Token token = context .parser ().nextToken (); token != Token .END_ARRAY ; token = context .parser ().nextToken ()) {
684- fieldMapper . checkDimensionExceeded (index , context );
697+ dimChecker . accept (index , false );
685698 ensureExpectedToken (Token .VALUE_NUMBER , token , context .parser ());
686699 float value = context .parser ().floatValue (true );
687700 vector [index ] = value ;
688701 squaredMagnitude += value * value ;
689702 index ++;
690703 }
691- fieldMapper . checkDimensionMatches (index , context );
704+ dimChecker . accept (index , true );
692705 checkVectorBounds (vector );
693- checkVectorMagnitude (fieldMapper . fieldType (). similarity , errorFloatElementsAppender (vector ), squaredMagnitude );
706+ checkVectorMagnitude (similarity , errorFloatElementsAppender (vector ), squaredMagnitude );
694707 return VectorData .fromFloats (vector );
695708 }
696709
@@ -815,12 +828,17 @@ public double computeSquaredMagnitude(VectorData vectorData) {
815828 return count ;
816829 }
817830
818- private VectorData parseVectorArray (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
831+ private VectorData parseVectorArray (
832+ DocumentParserContext context ,
833+ int dims ,
834+ IntBooleanConsumer dimChecker ,
835+ VectorSimilarity similarity
836+ ) throws IOException {
819837 int index = 0 ;
820- byte [] vector = new byte [fieldMapper . fieldType (). dims / Byte .SIZE ];
838+ byte [] vector = new byte [dims / Byte .SIZE ];
821839 for (XContentParser .Token token = context .parser ().nextToken (); token != Token .END_ARRAY ; token = context .parser ()
822840 .nextToken ()) {
823- fieldMapper . checkDimensionExceeded (index , context );
841+ dimChecker . accept (index * Byte . SIZE , false );
824842 ensureExpectedToken (Token .VALUE_NUMBER , token , context .parser ());
825843 final int value ;
826844 if (context .parser ().numberType () != XContentParser .NumberType .INT ) {
@@ -855,35 +873,25 @@ private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFi
855873 + "];"
856874 );
857875 }
858- if (index >= vector .length ) {
859- throw new IllegalArgumentException (
860- "The number of dimensions for field ["
861- + fieldMapper .fieldType ().name ()
862- + "] should be ["
863- + fieldMapper .fieldType ().dims
864- + "] but found ["
865- + (index + 1 ) * Byte .SIZE
866- + "]"
867- );
868- }
869876 vector [index ++] = (byte ) value ;
870877 }
871- fieldMapper . checkDimensionMatches (index * Byte .SIZE , context );
878+ dimChecker . accept (index * Byte .SIZE , true );
872879 return VectorData .fromBytes (vector );
873880 }
874881
875- private VectorData parseHexEncodedVector (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
882+ private VectorData parseHexEncodedVector (DocumentParserContext context , IntBooleanConsumer dimChecker ) throws IOException {
876883 byte [] decodedVector = HexFormat .of ().parseHex (context .parser ().text ());
877- fieldMapper . checkDimensionMatches (decodedVector .length * Byte .SIZE , context );
884+ dimChecker . accept (decodedVector .length * Byte .SIZE , true );
878885 return VectorData .fromBytes (decodedVector );
879886 }
880887
881888 @ Override
882- VectorData parseKnnVector (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
889+ VectorData parseKnnVector (DocumentParserContext context , int dims , IntBooleanConsumer dimChecker , VectorSimilarity similarity )
890+ throws IOException {
883891 XContentParser .Token token = context .parser ().currentToken ();
884892 return switch (token ) {
885- case START_ARRAY -> parseVectorArray (context , fieldMapper );
886- case VALUE_STRING -> parseHexEncodedVector (context , fieldMapper );
893+ case START_ARRAY -> parseVectorArray (context , dims , dimChecker , similarity );
894+ case VALUE_STRING -> parseHexEncodedVector (context , dimChecker );
887895 default -> throw new ParsingException (
888896 context .parser ().getTokenLocation (),
889897 format ("Unsupported type [%s] for provided value [%s]" , token , context .parser ().text ())
@@ -893,7 +901,13 @@ VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper
893901
894902 @ Override
895903 public void parseKnnVectorAndIndex (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException {
896- VectorData vectorData = parseKnnVector (context , fieldMapper );
904+ VectorData vectorData = parseKnnVector (context , fieldMapper .fieldType ().dims , (i , end ) -> {
905+ if (end ) {
906+ fieldMapper .checkDimensionMatches (i , context );
907+ } else {
908+ fieldMapper .checkDimensionExceeded (i , context );
909+ }
910+ }, fieldMapper .fieldType ().similarity );
897911 Field field = createKnnVectorField (
898912 fieldMapper .fieldType ().name (),
899913 vectorData .asByteVector (),
@@ -957,7 +971,12 @@ public void checkDimensions(Integer dvDims, int qvDims) {
957971
958972 abstract void parseKnnVectorAndIndex (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException ;
959973
960- abstract VectorData parseKnnVector (DocumentParserContext context , DenseVectorFieldMapper fieldMapper ) throws IOException ;
974+ abstract VectorData parseKnnVector (
975+ DocumentParserContext context ,
976+ int dims ,
977+ IntBooleanConsumer dimChecker ,
978+ VectorSimilarity similarity
979+ ) throws IOException ;
961980
962981 abstract int getNumBytes (int dimensions );
963982
@@ -2179,7 +2198,13 @@ private void parseBinaryDocValuesVectorAndIndex(DocumentParserContext context) t
21792198 : elementType .getNumBytes (dims );
21802199
21812200 ByteBuffer byteBuffer = elementType .createByteBuffer (indexCreatedVersion , numBytes );
2182- VectorData vectorData = elementType .parseKnnVector (context , this );
2201+ VectorData vectorData = elementType .parseKnnVector (context , dims , (i , b ) -> {
2202+ if (b ) {
2203+ checkDimensionMatches (i , context );
2204+ } else {
2205+ checkDimensionExceeded (i , context );
2206+ }
2207+ }, fieldType ().similarity );
21832208 vectorData .addToBuffer (byteBuffer );
21842209 if (indexCreatedVersion .onOrAfter (MAGNITUDE_STORED_INDEX_VERSION )) {
21852210 // encode vector magnitude at the end
@@ -2427,4 +2452,11 @@ public String fieldName() {
24272452 return fullPath ();
24282453 }
24292454 }
2455+
2456+ /**
2457+ * @FunctionalInterface for a function that takes a int and boolean
2458+ */
2459+ interface IntBooleanConsumer {
2460+ void accept (int value , boolean isComplete );
2461+ }
24302462}
0 commit comments