3636import org .apache .lucene .util .BitSetIterator ;
3737import org .apache .lucene .util .Bits ;
3838import org .apache .lucene .util .SparseFixedBitSet ;
39- import org .elasticsearch .common .util .concurrent .ConcurrentCollections ;
4039import org .elasticsearch .core .Releasable ;
4140import org .elasticsearch .lucene .util .CombinedBitSet ;
4241import org .elasticsearch .search .dfs .AggregatedDfs ;
5352import java .util .List ;
5453import java .util .Objects ;
5554import java .util .PriorityQueue ;
56- import java .util .Set ;
5755import java .util .concurrent .Callable ;
5856import java .util .concurrent .Executor ;
5957
@@ -80,7 +78,6 @@ public class ContextIndexSearcher extends IndexSearcher implements Releasable {
8078 // don't create slices with less than this number of docs
8179 private final int minimumDocsPerSlice ;
8280
83- private final Set <Thread > timeoutOverwrites = ConcurrentCollections .newConcurrentSet ();
8481 private volatile boolean timeExceeded = false ;
8582
8683 /** constructor for non-concurrent search */
@@ -356,6 +353,8 @@ private <C extends Collector, T> T search(Weight weight, CollectorManager<C, T>
356353 }
357354 }
358355
356+ private static final ThreadLocal <Boolean > timeoutOverwrites = ThreadLocal .withInitial (() -> false );
357+
359358 /**
360359 * Similar to the lucene implementation, with the following changes made:
361360 * 1) postCollection is performed after each segment is collected. This is needed for aggregations, performed by search threads
@@ -379,12 +378,12 @@ public void search(List<LeafReaderContext> leaves, Weight weight, Collector coll
379378 try {
380379 // Search phase has finished, no longer need to check for timeout
381380 // otherwise the aggregation post-collection phase might get cancelled.
382- boolean added = timeoutOverwrites .add ( Thread . currentThread ()) ;
383- assert added ;
381+ assert timeoutOverwrites .get () == false ;
382+ timeoutOverwrites . set ( true ) ;
384383 doAggregationPostCollection (collector );
385384 } finally {
386- boolean removed = timeoutOverwrites .remove ( Thread . currentThread () );
387- assert removed ;
385+ assert timeoutOverwrites .get ( );
386+ timeoutOverwrites . set ( false ) ;
388387 }
389388 }
390389 }
@@ -402,7 +401,7 @@ public boolean timeExceeded() {
402401 }
403402
404403 public void throwTimeExceededException () {
405- if (timeoutOverwrites .contains ( Thread . currentThread () ) == false ) {
404+ if (timeoutOverwrites .get ( ) == false ) {
406405 throw new TimeExceededException ();
407406 }
408407 }
0 commit comments