Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/126342.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 126342
summary: Enable sort optimization on float and `half_float`
area: Search
type: enhancement
issues: []

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.index.fielddata.IndexFieldData.XFieldComparatorSource.Nested;
import org.elasticsearch.index.fielddata.fieldcomparator.DoubleValuesComparatorSource;
import org.elasticsearch.index.fielddata.fieldcomparator.FloatValuesComparatorSource;
import org.elasticsearch.index.fielddata.fieldcomparator.HalfFloatValuesComparatorSource;
import org.elasticsearch.index.fielddata.fieldcomparator.LongValuesComparatorSource;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.MultiValueMode;
Expand Down Expand Up @@ -46,7 +47,7 @@ public enum NumericType {
LONG(false, SortField.Type.LONG, CoreValuesSourceType.NUMERIC),
DATE(false, SortField.Type.LONG, CoreValuesSourceType.DATE),
DATE_NANOSECONDS(false, SortField.Type.LONG, CoreValuesSourceType.DATE),
HALF_FLOAT(true, SortField.Type.LONG, CoreValuesSourceType.NUMERIC),
HALF_FLOAT(true, SortField.Type.FLOAT, CoreValuesSourceType.NUMERIC),
FLOAT(true, SortField.Type.FLOAT, CoreValuesSourceType.NUMERIC),
DOUBLE(true, SortField.Type.DOUBLE, CoreValuesSourceType.NUMERIC);

Expand Down Expand Up @@ -95,11 +96,13 @@ public final SortField sortField(
* 3. We aren't using max or min to resolve the duplicates.
* 4. We have to cast the results to another type.
*/
if (sortRequiresCustomComparator()
|| nested != null
boolean requiresCustomComparator = nested != null
|| (sortMode != MultiValueMode.MAX && sortMode != MultiValueMode.MIN)
|| targetNumericType != getNumericType()) {
return new SortField(getFieldName(), source, reverse);
|| targetNumericType != getNumericType();
if (sortRequiresCustomComparator() || requiresCustomComparator) {
SortField sortField = new SortField(getFieldName(), source, reverse);
sortField.setOptimizeSortWithPoints(requiresCustomComparator == false && isIndexed());
return sortField;
}

SortedNumericSelector.Type selectorType = sortMode == MultiValueMode.MAX
Expand All @@ -108,20 +111,18 @@ public final SortField sortField(
SortField sortField = new SortedNumericSortField(getFieldName(), getNumericType().sortFieldType, reverse, selectorType);
sortField.setMissingValue(source.missingObject(missingValue, reverse));

// TODO: Now that numeric sort uses indexed points to skip over non-competitive documents,
// Lucene 9 requires that the same data/type is stored in points and doc values.
// We break this assumption in ES by using the wider numeric sort type for every field,
// (e.g. shorts use longs and floats use doubles). So for now we forbid the usage of
// points in numeric sort on field types that use a different sort type.
// We could expose these optimizations for all numeric types but that would require
// to rewrite the logic to handle types when merging results coming from different
// indices.
// TODO: enable sort optimization for BYTE, SHORT and INT types
// They can use custom comparator logic, similarly to HalfFloatValuesComparatorSource.
// The problem comes from the fact that we use SortField.Type.LONG for all these types.
// Investigate how to resolve this.
switch (getNumericType()) {
case DATE_NANOSECONDS:
case DATE:
case LONG:
case DOUBLE:
// longs, doubles and dates use the same type for doc-values and points.
case FLOAT:
// longs, doubles and dates use the same type for doc-values and points
// floats use longs for doc-values, but Lucene's FloatComparator::getValueForDoc converts the long value to float
sortField.setOptimizeSortWithPoints(isIndexed());
break;

Expand Down Expand Up @@ -199,7 +200,8 @@ private XFieldComparatorSource comparatorSource(
Nested nested
) {
return switch (targetNumericType) {
case HALF_FLOAT, FLOAT -> new FloatValuesComparatorSource(this, missingValue, sortMode, nested);
case FLOAT -> new FloatValuesComparatorSource(this, missingValue, sortMode, nested);
case HALF_FLOAT -> new HalfFloatValuesComparatorSource(this, missingValue, sortMode, nested);
case DOUBLE -> new DoubleValuesComparatorSource(this, missingValue, sortMode, nested);
case DATE -> dateComparatorSource(missingValue, sortMode, nested);
case DATE_NANOSECONDS -> dateNanosComparatorSource(missingValue, sortMode, nested);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
*/
public class FloatValuesComparatorSource extends IndexFieldData.XFieldComparatorSource {

private final IndexNumericFieldData indexFieldData;
final IndexNumericFieldData indexFieldData;

public FloatValuesComparatorSource(
IndexNumericFieldData indexFieldData,
Expand All @@ -54,7 +54,7 @@ public SortField.Type reducedType() {
return SortField.Type.FLOAT;
}

private NumericDoubleValues getNumericDocValues(LeafReaderContext context, double missingValue) throws IOException {
NumericDoubleValues getNumericDocValues(LeafReaderContext context, double missingValue) throws IOException {
final SortedNumericDoubleValues values = indexFieldData.load(context).getDoubleValues();
if (nested == null) {
return FieldData.replaceMissing(sortMode.select(values), missingValue);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.fielddata.fieldcomparator;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.sandbox.document.HalfFloatPoint;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.comparators.NumericComparator;
import org.apache.lucene.util.BitUtil;

import java.io.IOException;

/**
 * Comparator for {@code half_float} values.
 * <p>
 * Provides skipping functionality: an iterator that can skip over non-competitive documents.
 * Values are compared as {@code float}s, while the {@code long} representations handed to the
 * base class are produced with {@link HalfFloatPoint#halfFloatToSortableShort(float)}.
 */
public class HalfFloatComparator extends NumericComparator<Float> {
    /** Value currently stored in each slot, indexed by slot number. */
    private final float[] slotValues;
    /** Primitive copy of the value last passed to {@link #setTopValue(Float)}. */
    protected float topValue;
    /** Primitive copy of the slot value last selected by the leaf comparator's {@code setBottom}. */
    protected float bottom;

    /**
     * @param numHits      number of slots to allocate
     * @param field        name of the field being sorted on
     * @param missingValue value used for documents without a value; {@code null} is replaced by {@code 0.0f}
     * @param reverse      forwarded unchanged to the base comparator
     * @param pruning      forwarded unchanged to the base comparator
     */
    public HalfFloatComparator(int numHits, String field, Float missingValue, boolean reverse, Pruning pruning) {
        super(field, missingValue == null ? 0.0f : missingValue, reverse, pruning, HalfFloatPoint.BYTES);
        slotValues = new float[numHits];
    }

    @Override
    public int compare(int slot1, int slot2) {
        final float first = slotValues[slot1];
        final float second = slotValues[slot2];
        return Float.compare(first, second);
    }

    @Override
    public void setTopValue(Float value) {
        super.setTopValue(value);
        // Keep an unboxed copy for the per-document comparisons.
        topValue = value.floatValue();
    }

    @Override
    public Float value(int slot) {
        return slotValues[slot];
    }

    @Override
    protected long missingValueAsComparableLong() {
        final short sortable = HalfFloatPoint.halfFloatToSortableShort(missingValue);
        return sortable;
    }

    @Override
    protected long sortableBytesToLong(byte[] bytes) {
        // Copied from HalfFloatPoint::sortableBytesToShort
        final short encoded = (short) BitUtil.VH_BE_SHORT.get(bytes, 0);
        // Re-flip the sign bit to restore the original value; the short result is then widened to long.
        final short decoded = (short) (encoded ^ 0x8000);
        return decoded;
    }

    @Override
    public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException {
        return new HalfFloatLeafComparator(context);
    }

    /** Leaf comparator for {@link HalfFloatComparator} that provides skipping functionality */
    public class HalfFloatLeafComparator extends NumericLeafComparator {

        public HalfFloatLeafComparator(LeafReaderContext context) throws IOException {
            super(context);
        }

        /**
         * Reads the value of {@code doc}: the doc value narrowed to {@code int} and reinterpreted as
         * float bits, or the missing value when the document has no doc value.
         */
        private float getValueForDoc(int doc) throws IOException {
            if (docValues.advanceExact(doc) == false) {
                return missingValue;
            }
            final int rawBits = (int) docValues.longValue();
            return Float.intBitsToFloat(rawBits);
        }

        @Override
        public void setBottom(int slot) throws IOException {
            // Assign before delegating: presumably the base class reads the new bottom via
            // bottomAsComparableLong() - verify against NumericLeafComparator before reordering.
            bottom = slotValues[slot];
            super.setBottom(slot);
        }

        @Override
        public int compareBottom(int doc) throws IOException {
            final float candidate = getValueForDoc(doc);
            return Float.compare(bottom, candidate);
        }

        @Override
        public int compareTop(int doc) throws IOException {
            final float candidate = getValueForDoc(doc);
            return Float.compare(topValue, candidate);
        }

        @Override
        public void copy(int slot, int doc) throws IOException {
            // Store the value first, then let the base class do its own bookkeeping.
            slotValues[slot] = getValueForDoc(doc);
            super.copy(slot, doc);
        }

        @Override
        protected long bottomAsComparableLong() {
            final short sortable = HalfFloatPoint.halfFloatToSortableShort(bottom);
            return sortable;
        }

        @Override
        protected long topAsComparableLong() {
            final short sortable = HalfFloatPoint.halfFloatToSortableShort(topValue);
            return sortable;
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.fielddata.fieldcomparator;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Pruning;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.fielddata.IndexNumericFieldData;
import org.elasticsearch.search.MultiValueMode;

import java.io.IOException;

/**
 * Comparator source for {@code half_float} values.
 * <p>
 * Creates {@link HalfFloatComparator} instances whose leaf comparators read their doc values
 * through this source, so that the sort mode, nested handling and missing-value replacement of
 * the parent {@link FloatValuesComparatorSource} are applied.
 */
public class HalfFloatValuesComparatorSource extends FloatValuesComparatorSource {
    public HalfFloatValuesComparatorSource(
        IndexNumericFieldData indexFieldData,
        @Nullable Object missingValue,
        MultiValueMode sortMode,
        Nested nested
    ) {
        super(indexFieldData, missingValue, sortMode, nested);
    }

    /**
     * Builds the comparator used to sort on {@code fieldname}.
     *
     * @param fieldname      field to sort on; asserted to match the field data's name when field data is present
     * @param numHits        number of slots the comparator allocates
     * @param enableSkipping pruning mode handed to the comparator
     * @param reversed       whether the sort order is reversed; also used to resolve the missing value
     */
    @Override
    public FieldComparator<?> newComparator(String fieldname, int numHits, Pruning enableSkipping, boolean reversed) {
        assert indexFieldData == null || fieldname.equals(indexFieldData.getFieldName());

        // Resolve the configured missing value once; every leaf comparator created below reuses it.
        final float resolvedMissing = ((Float) missingObject(missingValue, reversed)).floatValue();

        // NOTE: it's important to pass null as a missing value in the constructor so that
        // the comparator doesn't check docsWithField since we replace missing values in select()
        final HalfFloatComparator comparator = new HalfFloatComparator(numHits, fieldname, null, reversed, enableSkipping) {
            @Override
            public LeafFieldComparator getLeafComparator(LeafReaderContext leafContext) throws IOException {
                return new HalfFloatLeafComparator(leafContext) {
                    @Override
                    protected NumericDocValues getNumericDocValues(LeafReaderContext ctx, String field) throws IOException {
                        // Delegate to the enclosing source rather than reading the field's doc values directly.
                        return HalfFloatValuesComparatorSource.this.getNumericDocValues(ctx, resolvedMissing).getRawFloatValues();
                    }
                };
            }
        };
        return comparator;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import org.elasticsearch.index.fielddata.fieldcomparator.BytesRefFieldComparatorSource;
import org.elasticsearch.index.fielddata.fieldcomparator.DoubleValuesComparatorSource;
import org.elasticsearch.index.fielddata.fieldcomparator.FloatValuesComparatorSource;
import org.elasticsearch.index.fielddata.fieldcomparator.HalfFloatValuesComparatorSource;
import org.elasticsearch.index.fielddata.fieldcomparator.LongValuesComparatorSource;
import org.elasticsearch.search.MultiValueMode;
import org.elasticsearch.search.sort.ShardDocSortField;
Expand Down Expand Up @@ -627,7 +628,7 @@ private static Tuple<SortField, SortField> randomSortFieldCustomComparatorSource
IndexFieldData.XFieldComparatorSource comparatorSource;
boolean reverse = randomBoolean();
Object missingValue = null;
switch (randomIntBetween(0, 3)) {
switch (randomIntBetween(0, 4)) {
case 0 -> comparatorSource = new LongValuesComparatorSource(
null,
randomBoolean() ? randomLong() : null,
Expand All @@ -647,7 +648,13 @@ private static Tuple<SortField, SortField> randomSortFieldCustomComparatorSource
randomFrom(MultiValueMode.values()),
null
);
case 3 -> {
case 3 -> comparatorSource = new HalfFloatValuesComparatorSource(
null,
randomBoolean() ? randomFloat() : null,
randomFrom(MultiValueMode.values()),
null
);
case 4 -> {
comparatorSource = new BytesRefFieldComparatorSource(
null,
randomBoolean() ? "_first" : "_last",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,16 @@ public <IFD extends IndexFieldData<?>> IFD getForField(String type, String field
IndexVersion.current(),
null
).docValues(docValues).build(context).fieldType();
} else if (type.equals("half_float")) {
fieldType = new NumberFieldMapper.Builder(
fieldName,
NumberFieldMapper.NumberType.HALF_FLOAT,
ScriptCompiler.NONE,
false,
true,
IndexVersion.current(),
null
).docValues(docValues).build(context).fieldType();
} else if (type.equals("double")) {
fieldType = new NumberFieldMapper.Builder(
fieldName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,13 @@

import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.join.QueryBitSetProducer;
import org.apache.lucene.search.join.ScoreMode;
import org.apache.lucene.search.join.ToParentBlockJoinQuery;
import org.apache.lucene.util.NumericUtils;
import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.fielddata.IndexFieldData.XFieldComparatorSource;
import org.elasticsearch.index.fielddata.IndexFieldData.XFieldComparatorSource.Nested;
import org.elasticsearch.index.fielddata.IndexNumericFieldData;
import org.elasticsearch.index.fielddata.fieldcomparator.FloatValuesComparatorSource;
import org.elasticsearch.search.MultiValueMode;

import java.io.IOException;

import static org.hamcrest.Matchers.equalTo;

public class FloatNestedSortingTests extends DoubleNestedSortingTests {

@Override
Expand All @@ -55,39 +39,4 @@ protected IndexFieldData.XFieldComparatorSource createFieldComparator(
protected IndexableField createField(String name, int value) {
return new SortedNumericDocValuesField(name, NumericUtils.floatToSortableInt(value));
}

protected void assertAvgScoreMode(
Query parentFilter,
IndexSearcher searcher,
IndexFieldData.XFieldComparatorSource innerFieldComparator
) throws IOException {
MultiValueMode sortMode = MultiValueMode.AVG;
Query childFilter = Queries.not(parentFilter);
XFieldComparatorSource nestedComparatorSource = createFieldComparator(
"field2",
sortMode,
-127,
createNested(searcher, parentFilter, childFilter)
);
Query query = new ToParentBlockJoinQuery(
new ConstantScoreQuery(childFilter),
new QueryBitSetProducer(parentFilter),
ScoreMode.None
);
Sort sort = new Sort(new SortField("field2", nestedComparatorSource));
TopDocs topDocs = searcher.search(query, 5, sort);
assertThat(topDocs.totalHits.value(), equalTo(7L));
assertThat(topDocs.scoreDocs.length, equalTo(5));
assertThat(topDocs.scoreDocs[0].doc, equalTo(11));
assertThat(((Number) ((FieldDoc) topDocs.scoreDocs[0]).fields[0]).intValue(), equalTo(2));
assertThat(topDocs.scoreDocs[1].doc, equalTo(7));
assertThat(((Number) ((FieldDoc) topDocs.scoreDocs[1]).fields[0]).intValue(), equalTo(2));
assertThat(topDocs.scoreDocs[2].doc, equalTo(3));
assertThat(((Number) ((FieldDoc) topDocs.scoreDocs[2]).fields[0]).intValue(), equalTo(3));
assertThat(topDocs.scoreDocs[3].doc, equalTo(15));
assertThat(((Number) ((FieldDoc) topDocs.scoreDocs[3]).fields[0]).intValue(), equalTo(3));
assertThat(topDocs.scoreDocs[4].doc, equalTo(19));
assertThat(((Number) ((FieldDoc) topDocs.scoreDocs[4]).fields[0]).intValue(), equalTo(3));
}

}
Loading