diff --git a/docs/changelog/126342.yaml b/docs/changelog/126342.yaml new file mode 100644 index 0000000000000..b594deec97de5 --- /dev/null +++ b/docs/changelog/126342.yaml @@ -0,0 +1,5 @@ +pr: 126342 +summary: Enable sort optimization on float and `half_float` +area: Search +type: enhancement +issues: [] diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/sort/FieldSortIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/sort/FieldSortIT.java index 1d9bc96582ffb..fde4e944f924d 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/sort/FieldSortIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/sort/FieldSortIT.java @@ -524,6 +524,9 @@ public void testSimpleSorts() throws Exception { .startObject("float_value") .field("type", "float") .endObject() + .startObject("half_float_value") + .field("type", "half_float") + .endObject() .startObject("double_value") .field("type", "double") .endObject() @@ -534,7 +537,8 @@ public void testSimpleSorts() throws Exception { ); ensureGreen(); List builders = new ArrayList<>(); - for (int i = 0; i < 10; i++) { + final int numDocs = randomIntBetween(10, 127); + for (int i = 0; i < numDocs; i++) { IndexRequestBuilder builder = prepareIndex("test").setId(Integer.toString(i)) .setSource( jsonBuilder().startObject() @@ -545,6 +549,7 @@ public void testSimpleSorts() throws Exception { .field("integer_value", i) .field("long_value", i) .field("float_value", 0.1 * i) + .field("half_float_value", 0.1 * i) .field("double_value", 0.1 * i) .endObject() ); @@ -566,9 +571,9 @@ public void testSimpleSorts() throws Exception { // STRING { - int size = 1 + random.nextInt(10); + int size = 1 + random.nextInt(numDocs); assertResponse(prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("str_value", SortOrder.ASC), response -> { - assertHitCount(response, 10); + assertHitCount(response, numDocs); assertThat(response.getHits().getHits().length, 
equalTo(size)); for (int i = 0; i < size; i++) { assertThat(response.getHits().getAt(i).getId(), equalTo(Integer.toString(i))); @@ -580,15 +585,17 @@ public void testSimpleSorts() throws Exception { }); } { - int size = 1 + random.nextInt(10); + int size = 1 + random.nextInt(numDocs); assertResponse(prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("str_value", SortOrder.DESC), response -> { - assertHitCount(response, 10); + assertHitCount(response, numDocs); assertThat(response.getHits().getHits().length, equalTo(size)); for (int i = 0; i < size; i++) { - assertThat(response.getHits().getAt(i).getId(), equalTo(Integer.toString(9 - i))); + int expectedValue = numDocs - 1 - i; + SearchHit hit = response.getHits().getAt(i); + assertThat(hit.getId(), equalTo(Integer.toString(expectedValue))); assertThat( - response.getHits().getAt(i).getSortValues()[0].toString(), - equalTo(new String(new char[] { (char) (97 + (9 - i)), (char) (97 + (9 - i)) })) + hit.getSortValues()[0].toString(), + equalTo(new String(new char[] { (char) (97 + expectedValue), (char) (97 + expectedValue) })) ); } assertThat(response.toString(), not(containsString("error"))); @@ -596,9 +603,9 @@ public void testSimpleSorts() throws Exception { } // BYTE { - int size = 1 + random.nextInt(10); + int size = 1 + random.nextInt(numDocs); assertResponse(prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("byte_value", SortOrder.ASC), response -> { - assertHitCount(response, 10); + assertHitCount(response, numDocs); assertThat(response.getHits().getHits().length, equalTo(size)); for (int i = 0; i < size; i++) { assertThat(response.getHits().getAt(i).getId(), equalTo(Integer.toString(i))); @@ -607,22 +614,24 @@ public void testSimpleSorts() throws Exception { }); } { - int size = 1 + random.nextInt(10); + int size = 1 + random.nextInt(numDocs); assertResponse(prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("byte_value", SortOrder.DESC), response -> { - 
assertHitCount(response, 10); + assertHitCount(response, numDocs); assertThat(response.getHits().getHits().length, equalTo(size)); for (int i = 0; i < size; i++) { - assertThat(response.getHits().getAt(i).getId(), equalTo(Integer.toString(9 - i))); - assertThat(((Number) response.getHits().getAt(i).getSortValues()[0]).byteValue(), equalTo((byte) (9 - i))); + int expectedValue = numDocs - 1 - i; + SearchHit hit = response.getHits().getAt(i); + assertThat(hit.getId(), equalTo(Integer.toString(expectedValue))); + assertThat(((Number) hit.getSortValues()[0]).byteValue(), equalTo((byte) expectedValue)); } assertThat(response.toString(), not(containsString("error"))); }); } // SHORT { - int size = 1 + random.nextInt(10); + int size = 1 + random.nextInt(numDocs); assertResponse(prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("short_value", SortOrder.ASC), response -> { - assertHitCount(response, 10); + assertHitCount(response, numDocs); assertThat(response.getHits().getHits().length, equalTo(size)); for (int i = 0; i < size; i++) { assertThat(response.getHits().getAt(i).getId(), equalTo(Integer.toString(i))); @@ -631,22 +640,24 @@ public void testSimpleSorts() throws Exception { }); } { - int size = 1 + random.nextInt(10); + int size = 1 + random.nextInt(numDocs); assertResponse(prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("short_value", SortOrder.DESC), response -> { - assertHitCount(response, 10); + assertHitCount(response, numDocs); assertThat(response.getHits().getHits().length, equalTo(size)); for (int i = 0; i < size; i++) { - assertThat(response.getHits().getAt(i).getId(), equalTo(Integer.toString(9 - i))); - assertThat(((Number) response.getHits().getAt(i).getSortValues()[0]).shortValue(), equalTo((short) (9 - i))); + int expectedValue = numDocs - 1 - i; + SearchHit hit = response.getHits().getAt(i); + assertThat(hit.getId(), equalTo(Integer.toString(expectedValue))); + assertThat(((Number) hit.getSortValues()[0]).shortValue(), 
equalTo((short) expectedValue)); } assertThat(response.toString(), not(containsString("error"))); }); } // INTEGER { - int size = 1 + random.nextInt(10); + int size = 1 + random.nextInt(numDocs); assertResponse(prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("integer_value", SortOrder.ASC), response -> { - assertHitCount(response, 10); + assertHitCount(response, numDocs); assertThat(response.getHits().getHits().length, equalTo(size)); for (int i = 0; i < size; i++) { assertThat(response.getHits().getAt(i).getId(), equalTo(Integer.toString(i))); @@ -656,23 +667,24 @@ public void testSimpleSorts() throws Exception { }); } { - int size = 1 + random.nextInt(10); + int size = 1 + random.nextInt(numDocs); assertResponse(prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("integer_value", SortOrder.DESC), response -> { - assertHitCount(response, 10); + assertHitCount(response, numDocs); assertThat(response.getHits().getHits().length, equalTo(size)); for (int i = 0; i < size; i++) { - assertThat(response.getHits().getAt(i).getId(), equalTo(Integer.toString(9 - i))); - assertThat(((Number) response.getHits().getAt(i).getSortValues()[0]).intValue(), equalTo((9 - i))); + int expectedValue = numDocs - 1 - i; + SearchHit hit = response.getHits().getAt(i); + assertThat(hit.getId(), equalTo(Integer.toString(expectedValue))); + assertThat(((Number) hit.getSortValues()[0]).intValue(), equalTo(expectedValue)); } - assertThat(response.toString(), not(containsString("error"))); }); } // LONG { - int size = 1 + random.nextInt(10); + int size = 1 + random.nextInt(numDocs); assertResponse(prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("long_value", SortOrder.ASC), response -> { - assertHitCount(response, 10); + assertHitCount(response, numDocs); assertThat(response.getHits().getHits().length, equalTo(size)); for (int i = 0; i < size; i++) { assertThat(response.getHits().getAt(i).getId(), equalTo(Integer.toString(i))); @@ -683,22 +695,24 @@ 
public void testSimpleSorts() throws Exception { }); } { - int size = 1 + random.nextInt(10); + int size = 1 + random.nextInt(numDocs); assertResponse(prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("long_value", SortOrder.DESC), response -> { - assertHitCount(response, 10L); + assertHitCount(response, numDocs); assertThat(response.getHits().getHits().length, equalTo(size)); for (int i = 0; i < size; i++) { - assertThat(response.getHits().getAt(i).getId(), equalTo(Integer.toString(9 - i))); - assertThat(((Number) response.getHits().getAt(i).getSortValues()[0]).longValue(), equalTo((long) (9 - i))); + int expectedValue = numDocs - 1 - i; + SearchHit hit = response.getHits().getAt(i); + assertThat(hit.getId(), equalTo(Integer.toString(expectedValue))); + assertThat(((Number) hit.getSortValues()[0]).longValue(), equalTo((long) expectedValue)); } assertThat(response.toString(), not(containsString("error"))); }); } // FLOAT { - int size = 1 + random.nextInt(10); + int size = 1 + random.nextInt(numDocs); assertResponse(prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("float_value", SortOrder.ASC), response -> { - assertHitCount(response, 10L); + assertHitCount(response, numDocs); assertThat(response.getHits().getHits().length, equalTo(size)); for (int i = 0; i < size; i++) { assertThat(response.getHits().getAt(i).getId(), equalTo(Integer.toString(i))); @@ -708,22 +722,82 @@ public void testSimpleSorts() throws Exception { }); } { - int size = 1 + random.nextInt(10); + int size = 1 + random.nextInt(numDocs); assertResponse(prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("float_value", SortOrder.DESC), response -> { - assertHitCount(response, 10); + assertHitCount(response, numDocs); + assertThat(response.getHits().getHits().length, equalTo(size)); + for (int i = 0; i < size; i++) { + int expectedValue = numDocs - 1 - i; + SearchHit hit = response.getHits().getAt(i); + assertThat(hit.getId(), 
equalTo(Integer.toString(expectedValue))); + assertThat(((Number) hit.getSortValues()[0]).doubleValue(), closeTo(0.1d * expectedValue, 0.000001d)); + } + assertThat(response.toString(), not(containsString("error"))); + }); + } + { + // assert correctness of cast floats during sort (using numeric_type); no sort optimization is used + int size = 1 + random.nextInt(numDocs); + FieldSortBuilder sort = SortBuilders.fieldSort("float_value").order(SortOrder.ASC).setNumericType("double"); + assertResponse(prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort(sort), response -> { + assertHitCount(response, numDocs); + assertThat(response.getHits().getHits().length, equalTo(size)); + for (int i = 0; i < size; i++) { + assertThat(response.getHits().getAt(i).getId(), equalTo(Integer.toString(i))); + assertThat(((Number) response.getHits().getAt(i).getSortValues()[0]).doubleValue(), closeTo(0.1d * i, 0.000001d)); + } + assertThat(response.toString(), not(containsString("error"))); + }); + } + // HALF-FLOAT + { + int size = 1 + random.nextInt(numDocs); + assertResponse(prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("half_float_value", SortOrder.ASC), response -> { + assertHitCount(response, numDocs); + assertThat(response.getHits().getHits().length, equalTo(size)); + for (int i = 0; i < size; i++) { + assertThat(response.getHits().getAt(i).getId(), equalTo(Integer.toString(i))); + assertThat(((Number) response.getHits().getAt(i).getSortValues()[0]).doubleValue(), closeTo(0.1d * i, 0.004d)); + } + assertThat(response.toString(), not(containsString("error"))); + }); + } + { + int size = 1 + random.nextInt(numDocs); + assertResponse( + prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("half_float_value", SortOrder.DESC), + response -> { + assertHitCount(response, numDocs); + assertThat(response.getHits().getHits().length, equalTo(size)); + for (int i = 0; i < size; i++) { + int expectedValue = numDocs - 1 - i; + SearchHit hit = 
response.getHits().getAt(i); + assertThat(hit.getId(), equalTo(Integer.toString(expectedValue))); + assertThat(((Number) hit.getSortValues()[0]).doubleValue(), closeTo(0.1d * expectedValue, 0.004d)); + } + assertThat(response.toString(), not(containsString("error"))); + } + ); + } + { + // assert correctness of cast half_floats during sort (using numeric_type); no sort optimization is used + int size = 1 + random.nextInt(numDocs); + FieldSortBuilder sort = SortBuilders.fieldSort("half_float_value").order(SortOrder.ASC).setNumericType("double"); + assertResponse(prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort(sort), response -> { + assertHitCount(response, numDocs); assertThat(response.getHits().getHits().length, equalTo(size)); for (int i = 0; i < size; i++) { - assertThat(response.getHits().getAt(i).getId(), equalTo(Integer.toString(9 - i))); - assertThat(((Number) response.getHits().getAt(i).getSortValues()[0]).doubleValue(), closeTo(0.1d * (9 - i), 0.000001d)); + assertThat(response.getHits().getAt(i).getId(), equalTo(Integer.toString(i))); + assertThat(((Number) response.getHits().getAt(i).getSortValues()[0]).doubleValue(), closeTo(0.1d * i, 0.004)); } assertThat(response.toString(), not(containsString("error"))); }); } // DOUBLE { - int size = 1 + random.nextInt(10); + int size = 1 + random.nextInt(numDocs); assertResponse(prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("double_value", SortOrder.ASC), response -> { - assertHitCount(response, 10L); + assertHitCount(response, numDocs); assertThat(response.getHits().getHits().length, equalTo(size)); for (int i = 0; i < size; i++) { assertThat(response.getHits().getAt(i).getId(), equalTo(Integer.toString(i))); @@ -733,18 +807,17 @@ public void testSimpleSorts() throws Exception { }); } { - int size = 1 + random.nextInt(10); + int size = 1 + random.nextInt(numDocs); assertNoFailuresAndResponse( prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("double_value", 
SortOrder.DESC), response -> { - assertHitCount(response, 10L); + assertHitCount(response, numDocs); assertThat(response.getHits().getHits().length, equalTo(size)); for (int i = 0; i < size; i++) { - assertThat(response.getHits().getAt(i).getId(), equalTo(Integer.toString(9 - i))); - assertThat( - ((Number) response.getHits().getAt(i).getSortValues()[0]).doubleValue(), - closeTo(0.1d * (9 - i), 0.000001d) - ); + int expectedValue = numDocs - 1 - i; + SearchHit hit = response.getHits().getAt(i); + assertThat(hit.getId(), equalTo(Integer.toString(expectedValue))); + assertThat(((Number) hit.getSortValues()[0]).doubleValue(), closeTo(0.1d * expectedValue, 0.000001d)); } } ); @@ -756,63 +829,129 @@ public void testSortMissingNumbers() throws Exception { prepareCreate("test").setMapping( XContentFactory.jsonBuilder() .startObject() - .startObject("_doc") .startObject("properties") - .startObject("i_value") + .startObject("float_value") + .field("type", "float") + .endObject() + .startObject("int_value") .field("type", "integer") .endObject() - .startObject("d_value") - .field("type", "float") + .startObject("byte_value") + .field("type", "byte") .endObject() + .startObject("short_value") + .field("type", "short") + .endObject() + .startObject("long_value") + .field("type", "long") + .endObject() + .startObject("half_float_value") + .field("type", "half_float") + .endObject() + .startObject("double_value") + .field("type", "double") + .endObject() + .startObject("id") + .field("type", "keyword") .endObject() .endObject() .endObject() ) ); ensureGreen(); - prepareIndex("test").setId("1") - .setSource(jsonBuilder().startObject().field("id", "1").field("i_value", -1).field("d_value", -1.1).endObject()) - .get(); - prepareIndex("test").setId("2").setSource(jsonBuilder().startObject().field("id", "2").endObject()).get(); - - prepareIndex("test").setId("3") - .setSource(jsonBuilder().startObject().field("id", "1").field("i_value", 2).field("d_value", 2.2).endObject()) - 
.get(); + int numDocs = randomIntBetween(50, 127); + int missingRatio = 3; + BulkRequestBuilder bulk = client().prepareBulk(); - flush(); + List docsWithValues = new ArrayList<>(); + int misCount = 0; + for (int i = 0; i < numDocs; i++) { + if (i % missingRatio == 0) { + bulk.add( + prepareIndex("test").setId(Integer.toString(i)) + .setSource(jsonBuilder().startObject().field("id", Integer.toString(i)).endObject()) + ); + misCount++; + } else { + byte byteValue = (byte) (i % 127); + short shortValue = (short) (i * 2); + int intValue = i; + long longValue = i * 1000L; + float floatValue = (float) (i * 0.1); + float halfFloatValue = floatValue; + double doubleValue = i * 0.001; + bulk.add( + prepareIndex("test").setId(Integer.toString(i)) + .setSource( + jsonBuilder().startObject() + .field("id", Integer.toString(i)) + .field("byte_value", byteValue) + .field("short_value", shortValue) + .field("int_value", intValue) + .field("long_value", longValue) + .field("float_value", floatValue) + .field("half_float_value", halfFloatValue) + .field("double_value", doubleValue) + .endObject() + ) + ); + docsWithValues.add(i); + } + } + assertNoFailures(bulk.get()); refresh(); + final int missingCount = misCount; + final int withValuesCount = docsWithValues.size(); + + String[] fieldTypes = new String[] { + "byte_value", + "short_value", + "int_value", + "long_value", + "float_value", + "half_float_value", + "double_value" }; + + for (String fieldName : fieldTypes) { + // Test sorting with missing _last (default behavior) + assertNoFailuresAndResponse( + prepareSearch().setSize(numDocs).setQuery(matchAllQuery()).addSort(SortBuilders.fieldSort(fieldName).order(SortOrder.ASC)), + response -> { + assertEquals(numDocs, response.getHits().getHits().length); + for (int i = 0; i < docsWithValues.size(); i++) { + int expectedDocId = docsWithValues.get(i); + int actualDocId = Integer.parseInt(response.getHits().getAt(i).getId()); + assertEquals("Field " + fieldName + ": wrong doc at 
position " + i, expectedDocId, actualDocId); + } + // all documents with missing values should appear at the end + for (int i = 0; i < missingCount; i++) { + int actualDocId = Integer.parseInt(response.getHits().getAt(withValuesCount + i).getId()); + assertThat("Field " + fieldName + ": wrong missing doc at position " + i, actualDocId % missingRatio, equalTo(0)); + } + } + ); - logger.info("--> sort with no missing (same as missing _last)"); - assertNoFailuresAndResponse( - prepareSearch().setQuery(matchAllQuery()).addSort(SortBuilders.fieldSort("i_value").order(SortOrder.ASC)), - response -> { - assertThat(response.getHits().getTotalHits().value(), equalTo(3L)); - assertThat(response.getHits().getAt(0).getId(), equalTo("1")); - assertThat(response.getHits().getAt(1).getId(), equalTo("3")); - assertThat(response.getHits().getAt(2).getId(), equalTo("2")); - } - ); - logger.info("--> sort with missing _last"); - assertNoFailuresAndResponse( - prepareSearch().setQuery(matchAllQuery()).addSort(SortBuilders.fieldSort("i_value").order(SortOrder.ASC).missing("_last")), - response -> { - assertThat(response.getHits().getTotalHits().value(), equalTo(3L)); - assertThat(response.getHits().getAt(0).getId(), equalTo("1")); - assertThat(response.getHits().getAt(1).getId(), equalTo("3")); - assertThat(response.getHits().getAt(2).getId(), equalTo("2")); - } - ); - logger.info("--> sort with missing _first"); - assertNoFailuresAndResponse( - prepareSearch().setQuery(matchAllQuery()).addSort(SortBuilders.fieldSort("i_value").order(SortOrder.ASC).missing("_first")), - response -> { - assertThat(response.getHits().getTotalHits().value(), equalTo(3L)); - assertThat(response.getHits().getAt(0).getId(), equalTo("2")); - assertThat(response.getHits().getAt(1).getId(), equalTo("1")); - assertThat(response.getHits().getAt(2).getId(), equalTo("3")); - } - ); + // Test sorting with missing _first + assertNoFailuresAndResponse( + prepareSearch().setSize(numDocs) + .setQuery(matchAllQuery()) + 
.addSort(SortBuilders.fieldSort(fieldName).order(SortOrder.ASC).missing("_first")), + response -> { + assertEquals(numDocs, response.getHits().getHits().length); + // all documents with missing values should appear at the beginning + for (int i = 0; i < missingCount; i++) { + int actualDocId = Integer.parseInt(response.getHits().getAt(i).getId()); + assertThat("Field " + fieldName + ": wrong missing doc at position " + i, actualDocId % missingRatio, equalTo(0)); + } + for (int i = 0; i < docsWithValues.size(); i++) { + int expectedDocId = docsWithValues.get(i); + int actualDocId = Integer.parseInt(response.getHits().getAt(i + missingCount).getId()); + assertEquals("Field " + fieldName + ": wrong doc at position " + i, expectedDocId, actualDocId); + } + } + ); + } } public void testSortMissingStrings() throws IOException { diff --git a/server/src/main/java/org/elasticsearch/index/fielddata/IndexNumericFieldData.java b/server/src/main/java/org/elasticsearch/index/fielddata/IndexNumericFieldData.java index 289f1dd6abd25..98a5c8aed23c9 100644 --- a/server/src/main/java/org/elasticsearch/index/fielddata/IndexNumericFieldData.java +++ b/server/src/main/java/org/elasticsearch/index/fielddata/IndexNumericFieldData.java @@ -19,6 +19,7 @@ import org.elasticsearch.index.fielddata.IndexFieldData.XFieldComparatorSource.Nested; import org.elasticsearch.index.fielddata.fieldcomparator.DoubleValuesComparatorSource; import org.elasticsearch.index.fielddata.fieldcomparator.FloatValuesComparatorSource; +import org.elasticsearch.index.fielddata.fieldcomparator.HalfFloatValuesComparatorSource; import org.elasticsearch.index.fielddata.fieldcomparator.LongValuesComparatorSource; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.MultiValueMode; @@ -46,7 +47,7 @@ public enum NumericType { LONG(false, SortField.Type.LONG, CoreValuesSourceType.NUMERIC), DATE(false, SortField.Type.LONG, CoreValuesSourceType.DATE), DATE_NANOSECONDS(false, SortField.Type.LONG, 
CoreValuesSourceType.DATE), - HALF_FLOAT(true, SortField.Type.LONG, CoreValuesSourceType.NUMERIC), + HALF_FLOAT(true, SortField.Type.FLOAT, CoreValuesSourceType.NUMERIC), FLOAT(true, SortField.Type.FLOAT, CoreValuesSourceType.NUMERIC), DOUBLE(true, SortField.Type.DOUBLE, CoreValuesSourceType.NUMERIC); @@ -95,11 +96,13 @@ public final SortField sortField( * 3. We Aren't using max or min to resolve the duplicates. * 4. We have to cast the results to another type. */ - if (sortRequiresCustomComparator() - || nested != null + boolean requiresCustomComparator = nested != null || (sortMode != MultiValueMode.MAX && sortMode != MultiValueMode.MIN) - || targetNumericType != getNumericType()) { - return new SortField(getFieldName(), source, reverse); + || targetNumericType != getNumericType(); + if (sortRequiresCustomComparator() || requiresCustomComparator) { + SortField sortField = new SortField(getFieldName(), source, reverse); + sortField.setOptimizeSortWithPoints(requiresCustomComparator == false && isIndexed()); + return sortField; } SortedNumericSelector.Type selectorType = sortMode == MultiValueMode.MAX @@ -108,20 +111,18 @@ public final SortField sortField( SortField sortField = new SortedNumericSortField(getFieldName(), getNumericType().sortFieldType, reverse, selectorType); sortField.setMissingValue(source.missingObject(missingValue, reverse)); - // TODO: Now that numeric sort uses indexed points to skip over non-competitive documents, - // Lucene 9 requires that the same data/type is stored in points and doc values. - // We break this assumption in ES by using the wider numeric sort type for every field, - // (e.g. shorts use longs and floats use doubles). So for now we forbid the usage of - // points in numeric sort on field types that use a different sort type. - // We could expose these optimizations for all numeric types but that would require - // to rewrite the logic to handle types when merging results coming from different - // indices. 
+ // TODO: enable sort optimization for BYTE, SHORT and INT types + // They can use custom comparator logic, similarly to HalfFloatValuesComparatorSource. + // The problem comes from the fact that we use SortField.Type.LONG for all these types. + // Investigate how to resolve this. switch (getNumericType()) { case DATE_NANOSECONDS: case DATE: case LONG: case DOUBLE: - // longs, doubles and dates use the same type for doc-values and points. + case FLOAT: + // longs, doubles and dates use the same type for doc-values and points + // floats use longs for doc-values, but Lucene's FloatComparator::getValueForDoc converts long value to float sortField.setOptimizeSortWithPoints(isIndexed()); break; @@ -199,7 +200,8 @@ private XFieldComparatorSource comparatorSource( Nested nested ) { return switch (targetNumericType) { - case HALF_FLOAT, FLOAT -> new FloatValuesComparatorSource(this, missingValue, sortMode, nested); + case FLOAT -> new FloatValuesComparatorSource(this, missingValue, sortMode, nested); + case HALF_FLOAT -> new HalfFloatValuesComparatorSource(this, missingValue, sortMode, nested); case DOUBLE -> new DoubleValuesComparatorSource(this, missingValue, sortMode, nested); case DATE -> dateComparatorSource(missingValue, sortMode, nested); case DATE_NANOSECONDS -> dateNanosComparatorSource(missingValue, sortMode, nested); diff --git a/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/FloatValuesComparatorSource.java b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/FloatValuesComparatorSource.java index f1e4ba95e76fe..79f1fdb25a0a6 100644 --- a/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/FloatValuesComparatorSource.java +++ b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/FloatValuesComparatorSource.java @@ -37,7 +37,7 @@ */ public class FloatValuesComparatorSource extends IndexFieldData.XFieldComparatorSource { - private final IndexNumericFieldData indexFieldData; +
final IndexNumericFieldData indexFieldData; public FloatValuesComparatorSource( IndexNumericFieldData indexFieldData, @@ -54,7 +54,7 @@ public SortField.Type reducedType() { return SortField.Type.FLOAT; } - private NumericDoubleValues getNumericDocValues(LeafReaderContext context, double missingValue) throws IOException { + NumericDoubleValues getNumericDocValues(LeafReaderContext context, double missingValue) throws IOException { final SortedNumericDoubleValues values = indexFieldData.load(context).getDoubleValues(); if (nested == null) { return FieldData.replaceMissing(sortMode.select(values), missingValue); diff --git a/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/HalfFloatComparator.java b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/HalfFloatComparator.java new file mode 100644 index 0000000000000..efa41887a22fc --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/HalfFloatComparator.java @@ -0,0 +1,116 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.fielddata.fieldcomparator; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.sandbox.document.HalfFloatPoint; +import org.apache.lucene.search.LeafFieldComparator; +import org.apache.lucene.search.Pruning; +import org.apache.lucene.search.comparators.NumericComparator; +import org.apache.lucene.util.BitUtil; + +import java.io.IOException; + +/** + * Comparator for half_float values.
+ * This comparator provides a skipping functionality – an iterator that can skip over non-competitive documents. + */ +public class HalfFloatComparator extends NumericComparator { + private final float[] values; + protected float topValue; + protected float bottom; + + public HalfFloatComparator(int numHits, String field, Float missingValue, boolean reverse, Pruning pruning) { + super(field, missingValue != null ? missingValue : 0.0f, reverse, pruning, HalfFloatPoint.BYTES); + values = new float[numHits]; + } + + @Override + public int compare(int slot1, int slot2) { + return Float.compare(values[slot1], values[slot2]); + } + + @Override + public void setTopValue(Float value) { + super.setTopValue(value); + topValue = value; + } + + @Override + public Float value(int slot) { + return Float.valueOf(values[slot]); + } + + @Override + protected long missingValueAsComparableLong() { + return HalfFloatPoint.halfFloatToSortableShort(missingValue); + } + + @Override + protected long sortableBytesToLong(byte[] bytes) { + // Copied from HalfFloatPoint::sortableBytesToShort + short x = (short) BitUtil.VH_BE_SHORT.get(bytes, 0); + // Re-flip the sign bit to restore the original value: + return (short) (x ^ 0x8000); + } + + @Override + public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException { + return new HalfFloatLeafComparator(context); + } + + /** Leaf comparator for {@link HalfFloatComparator} that provides skipping functionality */ + public class HalfFloatLeafComparator extends NumericLeafComparator { + + public HalfFloatLeafComparator(LeafReaderContext context) throws IOException { + super(context); + } + + private float getValueForDoc(int doc) throws IOException { + if (docValues.advanceExact(doc)) { + return Float.intBitsToFloat((int) docValues.longValue()); + } else { + return missingValue; + } + } + + @Override + public void setBottom(int slot) throws IOException { + bottom = values[slot]; + super.setBottom(slot); + } + + @Override
+ public int compareBottom(int doc) throws IOException { + return Float.compare(bottom, getValueForDoc(doc)); + } + + @Override + public int compareTop(int doc) throws IOException { + return Float.compare(topValue, getValueForDoc(doc)); + } + + @Override + public void copy(int slot, int doc) throws IOException { + values[slot] = getValueForDoc(doc); + super.copy(slot, doc); + } + + @Override + protected long bottomAsComparableLong() { + return HalfFloatPoint.halfFloatToSortableShort(bottom); + } + + @Override + protected long topAsComparableLong() { + return HalfFloatPoint.halfFloatToSortableShort(topValue); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/HalfFloatValuesComparatorSource.java b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/HalfFloatValuesComparatorSource.java new file mode 100644 index 0000000000000..ade3f5ccc5a3a --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/HalfFloatValuesComparatorSource.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ +package org.elasticsearch.index.fielddata.fieldcomparator; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.LeafFieldComparator; +import org.apache.lucene.search.Pruning; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.fielddata.IndexNumericFieldData; +import org.elasticsearch.search.MultiValueMode; + +import java.io.IOException; + +/** + * Comparator source for half_float values. + */ +public class HalfFloatValuesComparatorSource extends FloatValuesComparatorSource { + public HalfFloatValuesComparatorSource( + IndexNumericFieldData indexFieldData, + @Nullable Object missingValue, + MultiValueMode sortMode, + Nested nested + ) { + super(indexFieldData, missingValue, sortMode, nested); + } + + @Override + public FieldComparator newComparator(String fieldname, int numHits, Pruning enableSkipping, boolean reversed) { + assert indexFieldData == null || fieldname.equals(indexFieldData.getFieldName()); + + final float fMissingValue = (Float) missingObject(missingValue, reversed); + // NOTE: it's important to pass null as a missing value in the constructor so that + // the comparator doesn't check docsWithField since we replace missing values in select() + return new HalfFloatComparator(numHits, fieldname, null, reversed, enableSkipping) { + @Override + public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException { + return new HalfFloatLeafComparator(context) { + @Override + protected NumericDocValues getNumericDocValues(LeafReaderContext context, String field) throws IOException { + return HalfFloatValuesComparatorSource.this.getNumericDocValues(context, fMissingValue).getRawFloatValues(); + } + }; + } + }; + } +} diff --git a/server/src/test/java/org/elasticsearch/common/lucene/LuceneTests.java b/server/src/test/java/org/elasticsearch/common/lucene/LuceneTests.java index 
d5fb33c9ec671..f6adb779055b0 100644 --- a/server/src/test/java/org/elasticsearch/common/lucene/LuceneTests.java +++ b/server/src/test/java/org/elasticsearch/common/lucene/LuceneTests.java @@ -60,6 +60,7 @@ import org.elasticsearch.index.fielddata.fieldcomparator.BytesRefFieldComparatorSource; import org.elasticsearch.index.fielddata.fieldcomparator.DoubleValuesComparatorSource; import org.elasticsearch.index.fielddata.fieldcomparator.FloatValuesComparatorSource; +import org.elasticsearch.index.fielddata.fieldcomparator.HalfFloatValuesComparatorSource; import org.elasticsearch.index.fielddata.fieldcomparator.LongValuesComparatorSource; import org.elasticsearch.search.MultiValueMode; import org.elasticsearch.search.sort.ShardDocSortField; @@ -627,7 +628,7 @@ private static Tuple randomSortFieldCustomComparatorSource IndexFieldData.XFieldComparatorSource comparatorSource; boolean reverse = randomBoolean(); Object missingValue = null; - switch (randomIntBetween(0, 3)) { + switch (randomIntBetween(0, 4)) { case 0 -> comparatorSource = new LongValuesComparatorSource( null, randomBoolean() ? randomLong() : null, @@ -647,7 +648,13 @@ private static Tuple randomSortFieldCustomComparatorSource randomFrom(MultiValueMode.values()), null ); - case 3 -> { + case 3 -> comparatorSource = new HalfFloatValuesComparatorSource( + null, + randomBoolean() ? randomFloat() : null, + randomFrom(MultiValueMode.values()), + null + ); + case 4 -> { comparatorSource = new BytesRefFieldComparatorSource( null, randomBoolean() ? 
"_first" : "_last", diff --git a/server/src/test/java/org/elasticsearch/index/fielddata/AbstractFieldDataTestCase.java b/server/src/test/java/org/elasticsearch/index/fielddata/AbstractFieldDataTestCase.java index f809a53d753fb..7ce3a0f6acabb 100644 --- a/server/src/test/java/org/elasticsearch/index/fielddata/AbstractFieldDataTestCase.java +++ b/server/src/test/java/org/elasticsearch/index/fielddata/AbstractFieldDataTestCase.java @@ -108,6 +108,16 @@ public > IFD getForField(String type, String field IndexVersion.current(), null ).docValues(docValues).build(context).fieldType(); + } else if (type.equals("half_float")) { + fieldType = new NumberFieldMapper.Builder( + fieldName, + NumberFieldMapper.NumberType.HALF_FLOAT, + ScriptCompiler.NONE, + false, + true, + IndexVersion.current(), + null + ).docValues(docValues).build(context).fieldType(); } else if (type.equals("double")) { fieldType = new NumberFieldMapper.Builder( fieldName, diff --git a/server/src/test/java/org/elasticsearch/index/search/nested/FloatNestedSortingTests.java b/server/src/test/java/org/elasticsearch/index/search/nested/FloatNestedSortingTests.java index 60e7473a2101a..0682734633a0f 100644 --- a/server/src/test/java/org/elasticsearch/index/search/nested/FloatNestedSortingTests.java +++ b/server/src/test/java/org/elasticsearch/index/search/nested/FloatNestedSortingTests.java @@ -10,29 +10,13 @@ import org.apache.lucene.document.SortedNumericDocValuesField; import org.apache.lucene.index.IndexableField; -import org.apache.lucene.search.ConstantScoreQuery; -import org.apache.lucene.search.FieldDoc; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.Sort; -import org.apache.lucene.search.SortField; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.join.QueryBitSetProducer; -import org.apache.lucene.search.join.ScoreMode; -import org.apache.lucene.search.join.ToParentBlockJoinQuery; import 
org.apache.lucene.util.NumericUtils; -import org.elasticsearch.common.lucene.search.Queries; import org.elasticsearch.index.fielddata.IndexFieldData; -import org.elasticsearch.index.fielddata.IndexFieldData.XFieldComparatorSource; import org.elasticsearch.index.fielddata.IndexFieldData.XFieldComparatorSource.Nested; import org.elasticsearch.index.fielddata.IndexNumericFieldData; import org.elasticsearch.index.fielddata.fieldcomparator.FloatValuesComparatorSource; import org.elasticsearch.search.MultiValueMode; -import java.io.IOException; - -import static org.hamcrest.Matchers.equalTo; - public class FloatNestedSortingTests extends DoubleNestedSortingTests { @Override @@ -55,39 +39,4 @@ protected IndexFieldData.XFieldComparatorSource createFieldComparator( protected IndexableField createField(String name, int value) { return new SortedNumericDocValuesField(name, NumericUtils.floatToSortableInt(value)); } - - protected void assertAvgScoreMode( - Query parentFilter, - IndexSearcher searcher, - IndexFieldData.XFieldComparatorSource innerFieldComparator - ) throws IOException { - MultiValueMode sortMode = MultiValueMode.AVG; - Query childFilter = Queries.not(parentFilter); - XFieldComparatorSource nestedComparatorSource = createFieldComparator( - "field2", - sortMode, - -127, - createNested(searcher, parentFilter, childFilter) - ); - Query query = new ToParentBlockJoinQuery( - new ConstantScoreQuery(childFilter), - new QueryBitSetProducer(parentFilter), - ScoreMode.None - ); - Sort sort = new Sort(new SortField("field2", nestedComparatorSource)); - TopDocs topDocs = searcher.search(query, 5, sort); - assertThat(topDocs.totalHits.value(), equalTo(7L)); - assertThat(topDocs.scoreDocs.length, equalTo(5)); - assertThat(topDocs.scoreDocs[0].doc, equalTo(11)); - assertThat(((Number) ((FieldDoc) topDocs.scoreDocs[0]).fields[0]).intValue(), equalTo(2)); - assertThat(topDocs.scoreDocs[1].doc, equalTo(7)); - assertThat(((Number) ((FieldDoc) 
topDocs.scoreDocs[1]).fields[0]).intValue(), equalTo(2)); - assertThat(topDocs.scoreDocs[2].doc, equalTo(3)); - assertThat(((Number) ((FieldDoc) topDocs.scoreDocs[2]).fields[0]).intValue(), equalTo(3)); - assertThat(topDocs.scoreDocs[3].doc, equalTo(15)); - assertThat(((Number) ((FieldDoc) topDocs.scoreDocs[3]).fields[0]).intValue(), equalTo(3)); - assertThat(topDocs.scoreDocs[4].doc, equalTo(19)); - assertThat(((Number) ((FieldDoc) topDocs.scoreDocs[4]).fields[0]).intValue(), equalTo(3)); - } - } diff --git a/server/src/test/java/org/elasticsearch/index/search/nested/HalfFloatNestedSortingTests.java b/server/src/test/java/org/elasticsearch/index/search/nested/HalfFloatNestedSortingTests.java new file mode 100644 index 0000000000000..671ff7967bfcb --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/search/nested/HalfFloatNestedSortingTests.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ +package org.elasticsearch.index.search.nested; + +import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.index.IndexableField; +import org.apache.lucene.sandbox.document.HalfFloatPoint; +import org.elasticsearch.index.fielddata.IndexFieldData.XFieldComparatorSource; +import org.elasticsearch.index.fielddata.IndexFieldData.XFieldComparatorSource.Nested; +import org.elasticsearch.index.fielddata.IndexNumericFieldData; +import org.elasticsearch.index.fielddata.fieldcomparator.HalfFloatValuesComparatorSource; +import org.elasticsearch.search.MultiValueMode; + +public class HalfFloatNestedSortingTests extends DoubleNestedSortingTests { + /** Names the field data type under test; the other overrides supply its comparator source and doc-value encoding. */ + @Override + protected String getFieldDataType() { + return "half_float"; + } + + @Override + protected XFieldComparatorSource createFieldComparator(String fieldName, MultiValueMode sortMode, Object missingValue, Nested nested) { + IndexNumericFieldData fieldData = getForField(fieldName); + return new HalfFloatValuesComparatorSource(fieldData, missingValue, sortMode, nested); + } + + @Override + protected IndexableField createField(String name, int value) { + return new SortedNumericDocValuesField(name, HalfFloatPoint.halfFloatToSortableShort(value)); + } +}