Skip to content

Commit 7b37e1f

Browse files
committed
Fix concurrency issue in ScriptSortBuilder
1 parent 79c388a commit 7b37e1f

File tree

3 files changed

+85
-18
lines changed

3 files changed

+85
-18
lines changed

server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/BytesRefFieldComparatorSource.java

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.apache.lucene.index.SortedSetDocValues;
1616
import org.apache.lucene.search.DocIdSetIterator;
1717
import org.apache.lucene.search.FieldComparator;
18+
import org.apache.lucene.search.LeafFieldComparator;
1819
import org.apache.lucene.search.Pruning;
1920
import org.apache.lucene.search.Scorable;
2021
import org.apache.lucene.search.SortField;
@@ -67,7 +68,7 @@ protected SortedBinaryDocValues getValues(LeafReaderContext context) throws IOEx
6768
return indexFieldData.load(context).getBytesValues();
6869
}
6970

70-
protected void setScorer(Scorable scorer) {}
71+
protected void setScorer(LeafReaderContext context, Scorable scorer) {}
7172

7273
@Override
7374
public FieldComparator<?> newComparator(String fieldname, int numHits, Pruning enableSkipping, boolean reversed) {
@@ -120,10 +121,38 @@ protected BinaryDocValues getBinaryDocValues(LeafReaderContext context, String f
120121
}
121122

122123
@Override
123-
public void setScorer(Scorable scorer) {
124-
BytesRefFieldComparatorSource.this.setScorer(scorer);
125-
}
124+
public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException {
125+
LeafFieldComparator leafComparator = super.getLeafComparator(context);
126+
return new LeafFieldComparator() {
127+
@Override
128+
public void setBottom(int slot) throws IOException {
129+
leafComparator.setBottom(slot);
130+
}
131+
132+
@Override
133+
public int compareBottom(int doc) throws IOException {
134+
return leafComparator.compareBottom(doc);
135+
}
136+
137+
@Override
138+
public int compareTop(int doc) throws IOException {
139+
return leafComparator.compareTop(doc);
140+
}
141+
142+
@Override
143+
public void copy(int slot, int doc) throws IOException {
144+
leafComparator.copy(slot, doc);
145+
}
126146

147+
@Override
148+
public void setScorer(Scorable scorer) throws IOException {
149+
// this ensures that the scorer is set for the specific leaf comparator
150+
// corresponding to the leaf context we are scoring
151+
// BytesRefFieldComparatorSource.this.setScorer(context, scorer);
152+
// TODO just an experiment to make sure that some test fails without it!
153+
}
154+
};
155+
}
127156
};
128157
}
129158

server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/DoubleValuesComparatorSource.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ private NumericDoubleValues getNumericDocValues(LeafReaderContext context, doubl
7171
}
7272
}
7373

74-
protected void setScorer(Scorable scorer) {}
74+
protected void setScorer(LeafReaderContext context, Scorable scorer) {}
7575

7676
@Override
7777
public FieldComparator<?> newComparator(String fieldname, int numHits, Pruning enableSkipping, boolean reversed) {
@@ -91,7 +91,7 @@ protected NumericDocValues getNumericDocValues(LeafReaderContext context, String
9191

9292
@Override
9393
public void setScorer(Scorable scorer) {
94-
DoubleValuesComparatorSource.this.setScorer(scorer);
94+
DoubleValuesComparatorSource.this.setScorer(context, scorer);
9595
}
9696
};
9797
}

server/src/main/java/org/elasticsearch/search/sort/ScriptSortBuilder.java

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.elasticsearch.common.io.stream.StreamOutput;
2222
import org.elasticsearch.common.io.stream.Writeable;
2323
import org.elasticsearch.common.util.BigArrays;
24+
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
2425
import org.elasticsearch.index.fielddata.AbstractBinaryDocValues;
2526
import org.elasticsearch.index.fielddata.FieldData;
2627
import org.elasticsearch.index.fielddata.IndexFieldData;
@@ -51,7 +52,9 @@
5152
import org.elasticsearch.xcontent.XContentParser;
5253

5354
import java.io.IOException;
55+
import java.io.UncheckedIOException;
5456
import java.util.Locale;
57+
import java.util.Map;
5558
import java.util.Objects;
5659

5760
import static org.elasticsearch.search.sort.FieldSortBuilder.validateMaxChildrenExistOnlyInTopLevelNestedSort;
@@ -278,11 +281,11 @@ private IndexFieldData.XFieldComparatorSource fieldComparatorSource(SearchExecut
278281
final StringSortScript.Factory factory = context.compile(script, StringSortScript.CONTEXT);
279282
final StringSortScript.LeafFactory searchScript = factory.newFactory(script.getParams());
280283
return new BytesRefFieldComparatorSource(null, null, valueMode, nested) {
281-
StringSortScript leafScript;
284+
final Map<Object, StringSortScript> leafScripts = ConcurrentCollections.newConcurrentMap();
282285

283286
@Override
284287
protected SortedBinaryDocValues getValues(LeafReaderContext context) throws IOException {
285-
leafScript = searchScript.newInstance(new DocValuesDocReader(searchLookup, context));
288+
StringSortScript leafScript = getLeafScript(context);
286289
final BinaryDocValues values = new AbstractBinaryDocValues() {
287290
final BytesRefBuilder spare = new BytesRefBuilder();
288291

@@ -302,8 +305,18 @@ public BytesRef binaryValue() {
302305
}
303306

304307
@Override
305-
protected void setScorer(Scorable scorer) {
306-
leafScript.setScorer(scorer);
308+
protected void setScorer(LeafReaderContext context, Scorable scorer) {
309+
getLeafScript(context).setScorer(scorer);
310+
}
311+
312+
StringSortScript getLeafScript(LeafReaderContext context) {
313+
return leafScripts.computeIfAbsent(context.id(), o -> {
314+
try {
315+
return searchScript.newInstance(new DocValuesDocReader(searchLookup, context));
316+
} catch (IOException e) {
317+
throw new UncheckedIOException(e);
318+
}
319+
});
307320
}
308321

309322
@Override
@@ -328,11 +341,11 @@ public BucketedSort newBucketedSort(
328341
// searchLookup is unnecessary here, as it's just used for expressions
329342
final NumberSortScript.LeafFactory numberSortScript = numberSortFactory.newFactory(script.getParams(), searchLookup);
330343
return new DoubleValuesComparatorSource(null, Double.MAX_VALUE, valueMode, nested) {
331-
NumberSortScript leafScript;
344+
final Map<Object, NumberSortScript> leafScripts = ConcurrentCollections.newConcurrentMap();
332345

333346
@Override
334347
protected SortedNumericDoubleValues getValues(LeafReaderContext context) throws IOException {
335-
leafScript = numberSortScript.newInstance(new DocValuesDocReader(searchLookup, context));
348+
NumberSortScript leafScript = getLeafScript(context);
336349
final NumericDoubleValues values = new NumericDoubleValues() {
337350
@Override
338351
public boolean advanceExact(int doc) {
@@ -349,20 +362,30 @@ public double doubleValue() {
349362
}
350363

351364
@Override
352-
protected void setScorer(Scorable scorer) {
353-
leafScript.setScorer(scorer);
365+
protected void setScorer(LeafReaderContext context, Scorable scorer) {
366+
getLeafScript(context).setScorer(scorer);
367+
}
368+
369+
NumberSortScript getLeafScript(LeafReaderContext context) {
370+
return leafScripts.computeIfAbsent(context.id(), o -> {
371+
try {
372+
return numberSortScript.newInstance(new DocValuesDocReader(searchLookup, context));
373+
} catch (IOException e) {
374+
throw new UncheckedIOException(e);
375+
}
376+
});
354377
}
355378
};
356379
}
357380
case VERSION -> {
358381
final BytesRefSortScript.Factory factory = context.compile(script, BytesRefSortScript.CONTEXT);
359382
final BytesRefSortScript.LeafFactory searchScript = factory.newFactory(script.getParams());
360383
return new BytesRefFieldComparatorSource(null, null, valueMode, nested) {
361-
BytesRefSortScript leafScript;
384+
final Map<Object, BytesRefSortScript> leafScripts = ConcurrentCollections.newConcurrentMap();
362385

363386
@Override
364387
protected SortedBinaryDocValues getValues(LeafReaderContext context) throws IOException {
365-
leafScript = searchScript.newInstance(new DocValuesDocReader(searchLookup, context));
388+
BytesRefSortScript leafScript = getLeafScript(context);
366389
final BinaryDocValues values = new AbstractBinaryDocValues() {
367390

368391
@Override
@@ -391,8 +414,18 @@ public BytesRef binaryValue() {
391414
}
392415

393416
@Override
394-
protected void setScorer(Scorable scorer) {
395-
leafScript.setScorer(scorer);
417+
protected void setScorer(LeafReaderContext context, Scorable scorer) {
418+
getLeafScript(context).setScorer(scorer);
419+
}
420+
421+
BytesRefSortScript getLeafScript(LeafReaderContext context) {
422+
return leafScripts.computeIfAbsent(context.id(), o -> {
423+
try {
424+
return searchScript.newInstance(new DocValuesDocReader(searchLookup, context));
425+
} catch (IOException e) {
426+
throw new UncheckedIOException(e);
427+
}
428+
});
396429
}
397430

398431
@Override
@@ -494,4 +527,9 @@ public ScriptSortBuilder rewrite(QueryRewriteContext ctx) throws IOException {
494527
}
495528
return new ScriptSortBuilder(this).setNestedSort(rewrite);
496529
}
530+
531+
@Override
532+
public boolean supportsParallelCollection() {
533+
return true;
534+
}
497535
}

0 commit comments

Comments
 (0)