diff --git a/docs/changelog/123757.yaml b/docs/changelog/123757.yaml new file mode 100644 index 0000000000000..5f29c43e8121d --- /dev/null +++ b/docs/changelog/123757.yaml @@ -0,0 +1,5 @@ +pr: 123757 +summary: Fix concurrency issue in `ScriptSortBuilder` +area: Search +type: bug +issues: [] diff --git a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/30_search.yml b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/30_search.yml index 5674d79b52a94..a7cf0a6bf9592 100644 --- a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/30_search.yml +++ b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/30_search.yml @@ -482,3 +482,118 @@ }] - match: { error.root_cause.0.type: "illegal_argument_exception" } - match: { error.root_cause.0.reason: "script score function must not produce negative scores, but got: [-9.0]"} + +--- + +"Script Sort + _score": + - do: + index: + index: test + id: "1" + body: { "test": "a", "num1": 1.0, "type" : "first" } + - do: + index: + index: test + id: "2" + body: { "test": "b", "num1": 2.0, "type" : "first" } + - do: + index: + index: test + id: "3" + body: { "test": "c", "num1": 3.0, "type" : "first" } + - do: + index: + index: test + id: "4" + body: { "test": "d", "num1": 4.0, "type" : "second" } + - do: + index: + index: test + id: "5" + body: { "test": "e", "num1": 5.0, "type" : "second" } + - do: + indices.refresh: {} + + - do: + search: + rest_total_hits_as_int: true + index: test + body: + sort: [ + { + _script: { + script: { + lang: "painless", + source: "doc['num1'].value + _score" + }, + type: "number" + } + } + ] + + - match: { hits.total: 5 } + - match: { hits.hits.0.sort.0: 2.0 } + - match: { hits.hits.1.sort.0: 3.0 } + - match: { hits.hits.2.sort.0: 4.0 } + - match: { hits.hits.3.sort.0: 5.0 } + - match: { hits.hits.4.sort.0: 6.0 } + + - do: + search: + rest_total_hits_as_int: true + index: test + body: + sort: [ + { + 
_script: { + script: { + lang: "painless", + source: "doc['test.keyword'].value + '-' + _score" + }, + type: "string" + } + } + ] + + - match: { hits.total: 5 } + - match: { hits.hits.0.sort.0: "a-1.0" } + - match: { hits.hits.1.sort.0: "b-1.0" } + - match: { hits.hits.2.sort.0: "c-1.0" } + - match: { hits.hits.3.sort.0: "d-1.0" } + - match: { hits.hits.4.sort.0: "e-1.0" } + + - do: + search: + rest_total_hits_as_int: true + index: test + body: + aggs: + test: + terms: + field: type.keyword + aggs: + top_hits: + top_hits: + sort: [ + { + _script: { + script: { + lang: "painless", + source: "doc['test.keyword'].value + '-' + _score" + }, + type: "string" + } + }, + "_score" + ] + size: 1 + + - match: { hits.total: 5 } + - match: { aggregations.test.buckets.0.key: "first" } + - match: { aggregations.test.buckets.0.top_hits.hits.total: 3 } + - match: { aggregations.test.buckets.0.top_hits.hits.hits.0.sort.0: "a-1.0" } + - match: { aggregations.test.buckets.0.top_hits.hits.hits.0.sort.1: 1.0 } + - match: { aggregations.test.buckets.1.key: "second" } + - match: { aggregations.test.buckets.1.top_hits.hits.total: 2 } + - match: { aggregations.test.buckets.1.top_hits.hits.hits.0.sort.0: "d-1.0" } + - match: { aggregations.test.buckets.1.top_hits.hits.hits.0.sort.1: 1.0 } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/metrics/TopHitsIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/metrics/TopHitsIT.java index c246b7cc2f5cc..781c399465cf3 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/metrics/TopHitsIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/metrics/TopHitsIT.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.document.DocumentField; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.fielddata.ScriptDocValues; import 
org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.seqno.SequenceNumbers; @@ -64,6 +65,7 @@ import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery; import static org.elasticsearch.index.query.QueryBuilders.matchQuery; import static org.elasticsearch.index.query.QueryBuilders.nestedQuery; +import static org.elasticsearch.script.MockScriptPlugin.NAME; import static org.elasticsearch.search.aggregations.AggregationBuilders.global; import static org.elasticsearch.search.aggregations.AggregationBuilders.histogram; import static org.elasticsearch.search.aggregations.AggregationBuilders.max; @@ -102,7 +104,12 @@ protected Collection> nodePlugins() { public static class CustomScriptPlugin extends MockScriptPlugin { @Override protected Map, Object>> pluginScripts() { - return Collections.singletonMap("5", script -> "5"); + return Map.of("5", script -> "5", "doc['sort'].value", CustomScriptPlugin::sortDoubleScript); + } + + private static Double sortDoubleScript(Map vars) { + Map doc = (Map) vars.get("doc"); + return ((Number) ((ScriptDocValues) doc.get("sort")).get(0)).doubleValue(); } @Override @@ -1268,6 +1275,41 @@ public void testWithRescore() { ); } + public void testScriptSorting() { + Script script = new Script(ScriptType.INLINE, NAME, "doc['sort'].value", Collections.emptyMap()); + assertNoFailuresAndResponse( + prepareSearch("idx").addAggregation( + terms("terms").executionHint(randomExecutionHint()) + .field(TERMS_AGGS_FIELD) + .subAggregation(topHits("hits").sort(SortBuilders.scriptSort(script, ScriptSortType.NUMBER).order(SortOrder.DESC))) + ), + response -> { + Terms terms = response.getAggregations().get("terms"); + assertThat(terms, notNullValue()); + assertThat(terms.getName(), equalTo("terms")); + assertThat(terms.getBuckets().size(), equalTo(5)); + + double higestSortValue = 0; + for (int i = 0; i < 5; i++) { + Terms.Bucket bucket = 
terms.getBucketByKey("val" + i); + assertThat(bucket, notNullValue()); + assertThat(key(bucket), equalTo("val" + i)); + assertThat(bucket.getDocCount(), equalTo(10L)); + TopHits topHits = bucket.getAggregations().get("hits"); + SearchHits hits = topHits.getHits(); + assertThat(hits.getTotalHits().value(), equalTo(10L)); + assertThat(hits.getHits().length, equalTo(3)); + higestSortValue += 10; + assertThat((Double) hits.getAt(0).getSortValues()[0], equalTo(higestSortValue)); + assertThat((Double) hits.getAt(1).getSortValues()[0], equalTo(higestSortValue - 1)); + assertThat((Double) hits.getAt(2).getSortValues()[0], equalTo(higestSortValue - 2)); + + assertThat(hits.getAt(0).getSourceAsMap().size(), equalTo(5)); + } + } + ); + } + public static class FetchPlugin extends Plugin implements SearchPlugin { @Override public List getFetchSubPhases(FetchPhaseConstructionContext context) { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/sort/FieldSortIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/sort/FieldSortIT.java index 7fd31b056779c..1d9bc96582ffb 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/sort/FieldSortIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/sort/FieldSortIT.java @@ -87,21 +87,21 @@ public static class CustomScriptPlugin extends MockScriptPlugin { @Override protected Map, Object>> pluginScripts() { Map, Object>> scripts = new HashMap<>(); - scripts.put("doc['number'].value", vars -> sortDoubleScript(vars)); - scripts.put("doc['keyword'].value", vars -> sortStringScript(vars)); + scripts.put("doc['number'].value", CustomScriptPlugin::sortDoubleScript); + scripts.put("doc['keyword'].value", CustomScriptPlugin::sortStringScript); return scripts; } - static Double sortDoubleScript(Map vars) { + private static Double sortDoubleScript(Map vars) { Map doc = (Map) vars.get("doc"); - Double index = ((Number) ((ScriptDocValues) 
doc.get("number")).get(0)).doubleValue(); - return index; + Double score = (Double) vars.get("_score"); + return ((Number) ((ScriptDocValues) doc.get("number")).get(0)).doubleValue() + score; } - static String sortStringScript(Map vars) { + private static String sortStringScript(Map vars) { Map doc = (Map) vars.get("doc"); - String value = ((String) ((ScriptDocValues) doc.get("keyword")).get(0)); - return value; + Double score = (Double) vars.get("_score"); + return ((ScriptDocValues) doc.get("keyword")).get(0) + ",_score=" + score; } } @@ -1665,14 +1665,14 @@ public void testCustomFormat() throws Exception { ); } - public void testScriptFieldSort() throws Exception { + public void testScriptFieldSort() { assertAcked(prepareCreate("test").setMapping("keyword", "type=keyword", "number", "type=integer")); ensureGreen(); final int numDocs = randomIntBetween(10, 20); IndexRequestBuilder[] indexReqs = new IndexRequestBuilder[numDocs]; List keywords = new ArrayList<>(); for (int i = 0; i < numDocs; ++i) { - indexReqs[i] = prepareIndex("test").setSource("number", i, "keyword", Integer.toString(i)); + indexReqs[i] = prepareIndex("test").setSource("number", i, "keyword", Integer.toString(i), "version", i + "." 
+ i); keywords.add(Integer.toString(i)); } Collections.sort(keywords); @@ -1686,7 +1686,7 @@ public void testScriptFieldSort() throws Exception { .addSort(SortBuilders.scriptSort(script, ScriptSortBuilder.ScriptSortType.NUMBER)) .addSort(SortBuilders.scoreSort()), response -> { - double expectedValue = 0; + double expectedValue = 1; // start from 1 because it includes _score, 1.0f for all docs for (SearchHit hit : response.getHits()) { assertThat(hit.getSortValues().length, equalTo(2)); assertThat(hit.getSortValues()[0], equalTo(expectedValue++)); @@ -1707,7 +1707,7 @@ public void testScriptFieldSort() throws Exception { int expectedValue = 0; for (SearchHit hit : response.getHits()) { assertThat(hit.getSortValues().length, equalTo(2)); - assertThat(hit.getSortValues()[0], equalTo(keywords.get(expectedValue++))); + assertThat(hit.getSortValues()[0], equalTo(keywords.get(expectedValue++) + ",_score=1.0")); assertThat(hit.getSortValues()[1], equalTo(1f)); } } diff --git a/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/BytesRefFieldComparatorSource.java b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/BytesRefFieldComparatorSource.java index 7da1e7d8a6790..1a12e4c7733a9 100644 --- a/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/BytesRefFieldComparatorSource.java +++ b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/BytesRefFieldComparatorSource.java @@ -15,6 +15,7 @@ import org.apache.lucene.index.SortedSetDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.LeafFieldComparator; import org.apache.lucene.search.Pruning; import org.apache.lucene.search.Scorable; import org.apache.lucene.search.SortField; @@ -67,7 +68,7 @@ protected SortedBinaryDocValues getValues(LeafReaderContext context) throws IOEx return indexFieldData.load(context).getBytesValues(); } - protected void 
setScorer(Scorable scorer) {} + protected void setScorer(LeafReaderContext context, Scorable scorer) {} @Override public FieldComparator newComparator(String fieldname, int numHits, Pruning enableSkipping, boolean reversed) { @@ -120,10 +121,43 @@ protected BinaryDocValues getBinaryDocValues(LeafReaderContext context, String f } @Override - public void setScorer(Scorable scorer) { - BytesRefFieldComparatorSource.this.setScorer(scorer); - } + public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException { + LeafFieldComparator leafComparator = super.getLeafComparator(context); + // TopFieldCollector interacts with inter-segment concurrency by creating a FieldValueHitQueue per slice, each one with a + // specific instance of the FieldComparator. This ensures sequential execution across LeafFieldComparators returned by + // the same parent FieldComparator. That allows for effectively sharing the same instance of leaf comparator, like in this + // case in the Lucene code. That's fine dealing with sorting by field, but not when using script sorting, because we then + // need to set the Scorer on the specific leaf comparator, to make the _score variable available in sort scripts. The + // setScorer call happens concurrently across slices and needs to target the specific leaf context that is being searched. 
+ return new LeafFieldComparator() { + @Override + public void setBottom(int slot) throws IOException { + leafComparator.setBottom(slot); + } + + @Override + public int compareBottom(int doc) throws IOException { + return leafComparator.compareBottom(doc); + } + + @Override + public int compareTop(int doc) throws IOException { + return leafComparator.compareTop(doc); + } + + @Override + public void copy(int slot, int doc) throws IOException { + leafComparator.copy(slot, doc); + } + @Override + public void setScorer(Scorable scorer) { + // this ensures that the scorer is set for the specific leaf comparator + // corresponding to the leaf context we are scoring + BytesRefFieldComparatorSource.this.setScorer(context, scorer); + } + }; + } }; } diff --git a/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/DoubleValuesComparatorSource.java b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/DoubleValuesComparatorSource.java index ae9ec46cf152a..c5fcb0207ce4d 100644 --- a/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/DoubleValuesComparatorSource.java +++ b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/DoubleValuesComparatorSource.java @@ -71,7 +71,7 @@ private NumericDoubleValues getNumericDocValues(LeafReaderContext context, doubl } } - protected void setScorer(Scorable scorer) {} + protected void setScorer(LeafReaderContext context, Scorable scorer) {} @Override public FieldComparator newComparator(String fieldname, int numHits, Pruning enableSkipping, boolean reversed) { @@ -91,7 +91,7 @@ protected NumericDocValues getNumericDocValues(LeafReaderContext context, String @Override public void setScorer(Scorable scorer) { - DoubleValuesComparatorSource.this.setScorer(scorer); + DoubleValuesComparatorSource.this.setScorer(context, scorer); } }; } diff --git a/server/src/main/java/org/elasticsearch/search/sort/ScriptSortBuilder.java 
b/server/src/main/java/org/elasticsearch/search/sort/ScriptSortBuilder.java index 445c55dc546bc..b3c88be60c179 100644 --- a/server/src/main/java/org/elasticsearch/search/sort/ScriptSortBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/sort/ScriptSortBuilder.java @@ -21,6 +21,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.index.fielddata.AbstractBinaryDocValues; import org.elasticsearch.index.fielddata.FieldData; import org.elasticsearch.index.fielddata.IndexFieldData; @@ -52,6 +53,7 @@ import java.io.IOException; import java.util.Locale; +import java.util.Map; import java.util.Objects; import static org.elasticsearch.search.sort.FieldSortBuilder.validateMaxChildrenExistOnlyInTopLevelNestedSort; @@ -278,11 +280,13 @@ private IndexFieldData.XFieldComparatorSource fieldComparatorSource(SearchExecut final StringSortScript.Factory factory = context.compile(script, StringSortScript.CONTEXT); final StringSortScript.LeafFactory searchScript = factory.newFactory(script.getParams()); return new BytesRefFieldComparatorSource(null, null, valueMode, nested) { - StringSortScript leafScript; + final Map leafScripts = ConcurrentCollections.newConcurrentMap(); @Override protected SortedBinaryDocValues getValues(LeafReaderContext context) throws IOException { - leafScript = searchScript.newInstance(new DocValuesDocReader(searchLookup, context)); + // we may see the same leaf context multiple times, and each time we need to refresh the doc values doc reader + StringSortScript leafScript = searchScript.newInstance(new DocValuesDocReader(searchLookup, context)); + leafScripts.put(context.id(), leafScript); final BinaryDocValues values = new AbstractBinaryDocValues() { final BytesRefBuilder spare = new BytesRefBuilder(); @@ -302,8 +306,8 @@ public BytesRef 
binaryValue() { } @Override - protected void setScorer(Scorable scorer) { - leafScript.setScorer(scorer); + protected void setScorer(LeafReaderContext context, Scorable scorer) { + leafScripts.get(context.id()).setScorer(scorer); } @Override @@ -326,13 +330,15 @@ public BucketedSort newBucketedSort( case NUMBER -> { final NumberSortScript.Factory numberSortFactory = context.compile(script, NumberSortScript.CONTEXT); // searchLookup is unnecessary here, as it's just used for expressions - final NumberSortScript.LeafFactory numberSortScript = numberSortFactory.newFactory(script.getParams(), searchLookup); + final NumberSortScript.LeafFactory numberSortScriptFactory = numberSortFactory.newFactory(script.getParams(), searchLookup); return new DoubleValuesComparatorSource(null, Double.MAX_VALUE, valueMode, nested) { - NumberSortScript leafScript; + final Map leafScripts = ConcurrentCollections.newConcurrentMap(); @Override protected SortedNumericDoubleValues getValues(LeafReaderContext context) throws IOException { - leafScript = numberSortScript.newInstance(new DocValuesDocReader(searchLookup, context)); + // we may see the same leaf context multiple times, and each time we need to refresh the doc values doc reader + NumberSortScript leafScript = numberSortScriptFactory.newInstance(new DocValuesDocReader(searchLookup, context)); + leafScripts.put(context.id(), leafScript); final NumericDoubleValues values = new NumericDoubleValues() { @Override public boolean advanceExact(int doc) { @@ -349,8 +355,8 @@ public double doubleValue() { } @Override - protected void setScorer(Scorable scorer) { - leafScript.setScorer(scorer); + protected void setScorer(LeafReaderContext context, Scorable scorer) { + leafScripts.get(context.id()).setScorer(scorer); } }; } @@ -358,11 +364,13 @@ protected void setScorer(Scorable scorer) { final BytesRefSortScript.Factory factory = context.compile(script, BytesRefSortScript.CONTEXT); final BytesRefSortScript.LeafFactory searchScript = 
factory.newFactory(script.getParams()); return new BytesRefFieldComparatorSource(null, null, valueMode, nested) { - BytesRefSortScript leafScript; + final Map leafScripts = ConcurrentCollections.newConcurrentMap(); @Override protected SortedBinaryDocValues getValues(LeafReaderContext context) throws IOException { - leafScript = searchScript.newInstance(new DocValuesDocReader(searchLookup, context)); + // we may see the same leaf context multiple times, and each time we need to refresh the doc values doc reader + BytesRefSortScript leafScript = searchScript.newInstance(new DocValuesDocReader(searchLookup, context)); + leafScripts.put(context.id(), leafScript); final BinaryDocValues values = new AbstractBinaryDocValues() { @Override @@ -391,8 +399,8 @@ public BytesRef binaryValue() { } @Override - protected void setScorer(Scorable scorer) { - leafScript.setScorer(scorer); + protected void setScorer(LeafReaderContext context, Scorable scorer) { + leafScripts.get(context.id()).setScorer(scorer); } @Override @@ -494,4 +502,9 @@ public ScriptSortBuilder rewrite(QueryRewriteContext ctx) throws IOException { } return new ScriptSortBuilder(this).setNestedSort(rewrite); } + + @Override + public boolean supportsParallelCollection() { + return true; + } } diff --git a/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java b/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java index 35777183ac18d..95217b6a005c4 100644 --- a/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java @@ -1023,10 +1023,10 @@ public void testSupportsParallelCollection() { SearchSourceBuilder searchSourceBuilder = newSearchSourceBuilder.get(); searchSourceBuilder.aggregation( new TopHitsAggregationBuilder("terms").sort( - SortBuilders.scriptSort(new Script("id"), ScriptSortBuilder.ScriptSortType.NUMBER) + 
SortBuilders.scriptSort(new Script("id"), randomFrom(ScriptSortBuilder.ScriptSortType.values())) ) ); - assertFalse(searchSourceBuilder.supportsParallelCollection(fieldCardinality)); + assertTrue(searchSourceBuilder.supportsParallelCollection(fieldCardinality)); } { SearchSourceBuilder searchSourceBuilder = newSearchSourceBuilder.get(); @@ -1047,7 +1047,7 @@ public void testSupportsParallelCollection() { ScriptSortBuilder.ScriptSortType.NUMBER ).order(randomFrom(SortOrder.values())) ); - assertFalse(searchSourceBuilder.supportsParallelCollection(fieldCardinality)); + assertTrue(searchSourceBuilder.supportsParallelCollection(fieldCardinality)); } { SearchSourceBuilder searchSourceBuilder = newSearchSourceBuilder.get(); diff --git a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java index 1e82313338b97..24d46b99b541b 100644 --- a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java +++ b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java @@ -137,6 +137,12 @@ public double execute() { Map vars = new HashMap<>(parameters); vars.put("params", parameters); vars.put("doc", getDoc()); + try { + vars.put("_score", get_score()); + } catch (Exception ignore) { + // nothing to do: if get_score throws we don't set the _score, likely the scorer is null, + // which is ok if _score was not requested e.g. top_hits. + } return ((Number) script.apply(vars)).doubleValue(); } }; @@ -881,6 +887,12 @@ public String execute() { Map vars = new HashMap<>(parameters); vars.put("params", parameters); vars.put("doc", getDoc()); + try { + vars.put("_score", get_score()); + } catch (Exception ignore) { + // nothing to do: if get_score throws we don't set the _score, likely the scorer is null, + // which is ok if _score was not requested e.g. top_hits. 
+ } return String.valueOf(script.apply(vars)); } }; @@ -907,6 +919,12 @@ public BytesRefProducer execute() { Map vars = new HashMap<>(parameters); vars.put("params", parameters); vars.put("doc", getDoc()); + try { + vars.put("_score", get_score()); + } catch (Exception ignore) { + // nothing to do: if get_score throws we don't set the _score, likely the scorer is null, + // which is ok if _score was not requested e.g. top_hits. + } return (BytesRefProducer) script.apply(vars); } };