Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/123757.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 123757
summary: Fix concurrency issue in `ScriptSortBuilder`
area: Search
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,118 @@
}]
- match: { error.root_cause.0.type: "illegal_argument_exception" }
- match: { error.root_cause.0.reason: "script score function must not produce negative scores, but got: [-9.0]"}

---

"Script Sort + _score":
- do:
index:
index: test
id: "1"
body: { "test": "a", "num1": 1.0, "type" : "first" }
- do:
index:
index: test
id: "2"
body: { "test": "b", "num1": 2.0, "type" : "first" }
- do:
index:
index: test
id: "3"
body: { "test": "c", "num1": 3.0, "type" : "first" }
- do:
index:
index: test
id: "4"
body: { "test": "d", "num1": 4.0, "type" : "second" }
- do:
index:
index: test
id: "5"
body: { "test": "e", "num1": 5.0, "type" : "second" }
- do:
indices.refresh: {}

- do:
search:
rest_total_hits_as_int: true
index: test
body:
sort: [
{
_script: {
script: {
lang: "painless",
source: "doc['num1'].value + _score"
},
type: "number"
}
}
]

- match: { hits.total: 5 }
- match: { hits.hits.0.sort.0: 2.0 }
- match: { hits.hits.1.sort.0: 3.0 }
- match: { hits.hits.2.sort.0: 4.0 }
- match: { hits.hits.3.sort.0: 5.0 }
- match: { hits.hits.4.sort.0: 6.0 }

- do:
search:
rest_total_hits_as_int: true
index: test
body:
sort: [
{
_script: {
script: {
lang: "painless",
source: "doc['test.keyword'].value + '-' + _score"
},
type: "string"
}
}
]

- match: { hits.total: 5 }
- match: { hits.hits.0.sort.0: "a-1.0" }
- match: { hits.hits.1.sort.0: "b-1.0" }
- match: { hits.hits.2.sort.0: "c-1.0" }
- match: { hits.hits.3.sort.0: "d-1.0" }
- match: { hits.hits.4.sort.0: "e-1.0" }

- do:
search:
rest_total_hits_as_int: true
index: test
body:
aggs:
test:
terms:
field: type.keyword
aggs:
top_hits:
top_hits:
sort: [
{
_script: {
script: {
lang: "painless",
source: "doc['test.keyword'].value + '-' + _score"
},
type: "string"
}
},
"_score"
]
size: 1

- match: { hits.total: 5 }
- match: { aggregations.test.buckets.0.key: "first" }
- match: { aggregations.test.buckets.0.top_hits.hits.total: 3 }
- match: { aggregations.test.buckets.0.top_hits.hits.hits.0.sort.0: "a-1.0" }
- match: { aggregations.test.buckets.0.top_hits.hits.hits.0.sort.1: 1.0 }
- match: { aggregations.test.buckets.1.key: "second" }
- match: { aggregations.test.buckets.1.top_hits.hits.total: 2 }
- match: { aggregations.test.buckets.1.top_hits.hits.hits.0.sort.0: "d-1.0" }
- match: { aggregations.test.buckets.1.top_hits.hits.hits.0.sort.1: 1.0 }
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.fielddata.ScriptDocValues;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.seqno.SequenceNumbers;
Expand Down Expand Up @@ -64,6 +65,7 @@
import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery;
import static org.elasticsearch.index.query.QueryBuilders.matchQuery;
import static org.elasticsearch.index.query.QueryBuilders.nestedQuery;
import static org.elasticsearch.script.MockScriptPlugin.NAME;
import static org.elasticsearch.search.aggregations.AggregationBuilders.global;
import static org.elasticsearch.search.aggregations.AggregationBuilders.histogram;
import static org.elasticsearch.search.aggregations.AggregationBuilders.max;
Expand Down Expand Up @@ -102,7 +104,12 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {
public static class CustomScriptPlugin extends MockScriptPlugin {
@Override
protected Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
return Collections.singletonMap("5", script -> "5");
return Map.of("5", script -> "5", "doc['sort'].value", CustomScriptPlugin::sortDoubleScript);
}

private static Double sortDoubleScript(Map<String, Object> vars) {
Map<?, ?> doc = (Map) vars.get("doc");
return ((Number) ((ScriptDocValues<?>) doc.get("sort")).get(0)).doubleValue();
}

@Override
Expand Down Expand Up @@ -1268,6 +1275,41 @@ public void testWithRescore() {
);
}

public void testScriptSorting() {
Script script = new Script(ScriptType.INLINE, NAME, "doc['sort'].value", Collections.emptyMap());
assertNoFailuresAndResponse(
prepareSearch("idx").addAggregation(
terms("terms").executionHint(randomExecutionHint())
.field(TERMS_AGGS_FIELD)
.subAggregation(topHits("hits").sort(SortBuilders.scriptSort(script, ScriptSortType.NUMBER).order(SortOrder.DESC)))
),
response -> {
Terms terms = response.getAggregations().get("terms");
assertThat(terms, notNullValue());
assertThat(terms.getName(), equalTo("terms"));
assertThat(terms.getBuckets().size(), equalTo(5));

double higestSortValue = 0;
for (int i = 0; i < 5; i++) {
Terms.Bucket bucket = terms.getBucketByKey("val" + i);
assertThat(bucket, notNullValue());
assertThat(key(bucket), equalTo("val" + i));
assertThat(bucket.getDocCount(), equalTo(10L));
TopHits topHits = bucket.getAggregations().get("hits");
SearchHits hits = topHits.getHits();
assertThat(hits.getTotalHits().value(), equalTo(10L));
assertThat(hits.getHits().length, equalTo(3));
higestSortValue += 10;
assertThat((Double) hits.getAt(0).getSortValues()[0], equalTo(higestSortValue));
assertThat((Double) hits.getAt(1).getSortValues()[0], equalTo(higestSortValue - 1));
assertThat((Double) hits.getAt(2).getSortValues()[0], equalTo(higestSortValue - 2));

assertThat(hits.getAt(0).getSourceAsMap().size(), equalTo(5));
}
}
);
}

public static class FetchPlugin extends Plugin implements SearchPlugin {
@Override
public List<FetchSubPhase> getFetchSubPhases(FetchPhaseConstructionContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,21 +87,21 @@ public static class CustomScriptPlugin extends MockScriptPlugin {
@Override
protected Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
Map<String, Function<Map<String, Object>, Object>> scripts = new HashMap<>();
scripts.put("doc['number'].value", vars -> sortDoubleScript(vars));
scripts.put("doc['keyword'].value", vars -> sortStringScript(vars));
scripts.put("doc['number'].value", CustomScriptPlugin::sortDoubleScript);
scripts.put("doc['keyword'].value", CustomScriptPlugin::sortStringScript);
return scripts;
}

static Double sortDoubleScript(Map<String, Object> vars) {
private static Double sortDoubleScript(Map<String, Object> vars) {
Map<?, ?> doc = (Map) vars.get("doc");
Double index = ((Number) ((ScriptDocValues<?>) doc.get("number")).get(0)).doubleValue();
return index;
Double score = (Double) vars.get("_score");
return ((Number) ((ScriptDocValues<?>) doc.get("number")).get(0)).doubleValue() + score;
}

static String sortStringScript(Map<String, Object> vars) {
private static String sortStringScript(Map<String, Object> vars) {
Map<?, ?> doc = (Map) vars.get("doc");
String value = ((String) ((ScriptDocValues<?>) doc.get("keyword")).get(0));
return value;
Double score = (Double) vars.get("_score");
return ((ScriptDocValues<?>) doc.get("keyword")).get(0) + ",_score=" + score;
}
}

Expand Down Expand Up @@ -1665,14 +1665,14 @@ public void testCustomFormat() throws Exception {
);
}

public void testScriptFieldSort() throws Exception {
public void testScriptFieldSort() {
assertAcked(prepareCreate("test").setMapping("keyword", "type=keyword", "number", "type=integer"));
ensureGreen();
final int numDocs = randomIntBetween(10, 20);
IndexRequestBuilder[] indexReqs = new IndexRequestBuilder[numDocs];
List<String> keywords = new ArrayList<>();
for (int i = 0; i < numDocs; ++i) {
indexReqs[i] = prepareIndex("test").setSource("number", i, "keyword", Integer.toString(i));
indexReqs[i] = prepareIndex("test").setSource("number", i, "keyword", Integer.toString(i), "version", i + "." + i);
keywords.add(Integer.toString(i));
}
Collections.sort(keywords);
Expand All @@ -1686,7 +1686,7 @@ public void testScriptFieldSort() throws Exception {
.addSort(SortBuilders.scriptSort(script, ScriptSortBuilder.ScriptSortType.NUMBER))
.addSort(SortBuilders.scoreSort()),
response -> {
double expectedValue = 0;
double expectedValue = 1; // start from 1 because it includes _score, 1.0f for all docs
for (SearchHit hit : response.getHits()) {
assertThat(hit.getSortValues().length, equalTo(2));
assertThat(hit.getSortValues()[0], equalTo(expectedValue++));
Expand All @@ -1707,7 +1707,7 @@ public void testScriptFieldSort() throws Exception {
int expectedValue = 0;
for (SearchHit hit : response.getHits()) {
assertThat(hit.getSortValues().length, equalTo(2));
assertThat(hit.getSortValues()[0], equalTo(keywords.get(expectedValue++)));
assertThat(hit.getSortValues()[0], equalTo(keywords.get(expectedValue++) + ",_score=1.0"));
assertThat(hit.getSortValues()[1], equalTo(1f));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.SortField;
Expand Down Expand Up @@ -67,7 +68,7 @@ protected SortedBinaryDocValues getValues(LeafReaderContext context) throws IOEx
return indexFieldData.load(context).getBytesValues();
}

protected void setScorer(Scorable scorer) {}
protected void setScorer(LeafReaderContext context, Scorable scorer) {}

@Override
public FieldComparator<?> newComparator(String fieldname, int numHits, Pruning enableSkipping, boolean reversed) {
Expand Down Expand Up @@ -120,10 +121,43 @@ protected BinaryDocValues getBinaryDocValues(LeafReaderContext context, String f
}

@Override
public void setScorer(Scorable scorer) {
BytesRefFieldComparatorSource.this.setScorer(scorer);
}
public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException {
LeafFieldComparator leafComparator = super.getLeafComparator(context);
// TopFieldCollector interacts with inter-segment concurrency by creating a FieldValueHitQueue per slice, each one with a
// specific instance of the FieldComparator. This ensures sequential execution across LeafFieldComparators returned by
// the same parent FieldComparator. That allows for effectively sharing the same instance of leaf comparator, like in this
// case in the Lucene code. That's fine dealing with sorting by field, but not when using script sorting, because we then
// need to set to Scorer to the specific leaf comparator, to make the _score variable available in sort scripts. The
// setScorer call happens concurrently across slices and needs to target the specific leaf context that is being searched.
return new LeafFieldComparator() {
@Override
public void setBottom(int slot) throws IOException {
leafComparator.setBottom(slot);
}

@Override
public int compareBottom(int doc) throws IOException {
return leafComparator.compareBottom(doc);
}

@Override
public int compareTop(int doc) throws IOException {
return leafComparator.compareTop(doc);
}

@Override
public void copy(int slot, int doc) throws IOException {
leafComparator.copy(slot, doc);
}

@Override
public void setScorer(Scorable scorer) {
// this ensures that the scorer is set for the specific leaf comparator
// corresponding to the leaf context we are scoring
BytesRefFieldComparatorSource.this.setScorer(context, scorer);
}
};
}
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ private NumericDoubleValues getNumericDocValues(LeafReaderContext context, doubl
}
}

protected void setScorer(Scorable scorer) {}
protected void setScorer(LeafReaderContext context, Scorable scorer) {}

@Override
public FieldComparator<?> newComparator(String fieldname, int numHits, Pruning enableSkipping, boolean reversed) {
Expand All @@ -91,7 +91,7 @@ protected NumericDocValues getNumericDocValues(LeafReaderContext context, String

@Override
public void setScorer(Scorable scorer) {
DoubleValuesComparatorSource.this.setScorer(scorer);
DoubleValuesComparatorSource.this.setScorer(context, scorer);
}
};
}
Expand Down
Loading