diff --git a/docs/changelog/126702.yaml b/docs/changelog/126702.yaml new file mode 100644 index 0000000000000..a6def67c08c6d --- /dev/null +++ b/docs/changelog/126702.yaml @@ -0,0 +1,5 @@ +pr: 126702 +summary: "Return float[] instead of List in `valueFetcher`" +area: Search +type: enhancement +issues: [] diff --git a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java index 2686ab9c0b016..8f2e74f3bc130 100644 --- a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java +++ b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java @@ -180,7 +180,20 @@ public ValueFetcher valueFetcher(SearchExecutionContext context, String format) return new ArraySourceValueFetcher(name(), context) { @Override protected Object parseSourceValue(Object value) { - return value; + List outerList = (List) value; + List vectors = new ArrayList<>(outerList.size()); + for (Object o : outerList) { + if (o instanceof List innerList) { + float[] vector = new float[innerList.size()]; + for (int i = 0; i < vector.length; i++) { + vector[i] = ((Number) innerList.get(i)).floatValue(); + } + vectors.add(vector); + } else { + vectors.add(o); + } + } + return vectors; } }; } diff --git a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java index 25264976c58a5..ac7ed8cf1c07f 100644 --- a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java +++ b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java @@ -383,19 +383,10 @@ protected void assertFetch(MapperService mapperService, String field, Object val switch (denseVectorFieldType.getElementType()) { case BYTE -> assumeFalse("byte element type testing not currently added", false); case FLOAT -> { - List fetchedFloatsList = new ArrayList<>(); - for (var f : fromNative) { - float[] fetchedFloats = new float[denseVectorFieldType.getVectorDimensions()]; - assert f instanceof List; - List vector = (List) f; - int i = 0; - for (Object v : vector) { - assert v instanceof Number; - fetchedFloats[i++] = ((Number) v).floatValue(); - } - fetchedFloatsList.add(fetchedFloats); + float[][] fetchedFloats = new float[fromNative.size()][]; + for (int i = 0; i < fromNative.size(); i++) { + fetchedFloats[i] = (float[]) fromNative.get(i); } - float[][] fetchedFloats = fetchedFloatsList.toArray(new float[0][]); assertThat("fetching " + value, fetchedFloats, equalTo(value)); } } diff --git a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldTypeTests.java b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldTypeTests.java index 59c5b0414910d..205a67accb79d 100644 --- a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldTypeTests.java +++ b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldTypeTests.java @@ -19,10 +19,13 @@ import java.io.IOException; import java.util.Collections; +import java.util.HexFormat; import java.util.List; import java.util.Set; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.iterableWithSize; public class RankVectorsFieldTypeTests extends FieldTypeTestCase { @@ -86,9 +89,16 @@ public void testDocValueFormat() { public void testFetchSourceValue() throws IOException { RankVectorsFieldType fft = createFloatFieldType(); - List> vector = List.of(List.of(0.0, 1.0, 2.0, 3.0, 4.0, 6.0)); - assertEquals(vector, fetchSourceValue(fft, vector)); + List> vectorFromXContent = List.of(List.of(0.0, 1.0, 2.0, 3.0, 4.0, 6.0)); + List vector = List.of(new float[] { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 6.0f }); + assertThat(fetchSourceValue(fft, vectorFromXContent), iterableWithSize(1)); + assertThat(fetchSourceValue(fft, vectorFromXContent).get(0), equalTo(vector.get(0))); RankVectorsFieldType bft = createByteFieldType(); - assertEquals(vector, fetchSourceValue(bft, vector)); + assertThat(fetchSourceValue(bft, vectorFromXContent), iterableWithSize(1)); + assertThat(fetchSourceValue(bft, vectorFromXContent).get(0), equalTo(vector.get(0))); + String hexStr = HexFormat.of().formatHex(new byte[] { 0, 1, 2, 3, 4, 6 }); + List hexVecs = List.of(hexStr); + assertThat(fetchSourceValue(bft, hexVecs), iterableWithSize(1)); + assertThat(fetchSourceValue(bft, hexVecs).get(0), equalTo(hexStr)); } }