1313import org .apache .logging .log4j .Logger ;
1414import org .apache .lucene .index .LeafReaderContext ;
1515import org .apache .lucene .search .TotalHits ;
16- import org .elasticsearch .common .MemoryAccountingBytesRefCounted ;
1716import org .elasticsearch .common .breaker .CircuitBreaker ;
1817import org .elasticsearch .common .breaker .CircuitBreakingException ;
1918import org .elasticsearch .common .bytes .BytesReference ;
19+ import org .elasticsearch .core .AbstractRefCounted ;
2020import org .elasticsearch .core .RefCounted ;
2121import org .elasticsearch .index .fieldvisitor .LeafStoredFieldLoader ;
2222import org .elasticsearch .index .fieldvisitor .StoredFieldLoader ;
@@ -291,7 +291,7 @@ private HitContext prepareHitContext(
291291 RankDoc rankDoc ,
292292 CircuitBreaker circuitBreaker ,
293293 boolean submitToCb ,
294- IntBooleanFunction memoryUsageAccumulator
294+ MemoryUsageAccumulator memoryUsageAccumulator
295295 ) throws IOException {
296296 if (nestedDocuments .advance (docId - subReaderContext .docBase ) == null ) {
297297 return prepareNonNestedHitContext (
@@ -339,7 +339,7 @@ private static HitContext prepareNonNestedHitContext(
339339 RankDoc rankDoc ,
340340 CircuitBreaker circuitBreaker ,
341341 boolean submitToCB ,
342- IntBooleanFunction memoryUsageAccumulator
342+ MemoryUsageAccumulator memoryUsageAccumulator
343343 ) throws IOException {
344344 int subDocId = docId - subReaderContext .docBase ;
345345
@@ -368,7 +368,7 @@ private static HitContext prepareNonNestedHitContext(
368368 source = sourceLoader .source (leafStoredFieldLoader , subDocId );
369369 int accumulatedInLeaf = memoryUsageAccumulator .apply (source .internalSourceRef ().length (), submitToCB );
370370 if (submitToCB ) {
371- memAccountingRefCounted .setBytesAndAccount (accumulatedInLeaf , "fetch phase source loader" );
371+ memAccountingRefCounted .account (accumulatedInLeaf , "fetch phase source loader" );
372372 }
373373 } catch (CircuitBreakingException e ) {
374374 hit .decRef ();
@@ -400,7 +400,7 @@ private static Supplier<Source> lazyStoredSourceLoader(
400400 int doc ,
401401 MemoryAccountingBytesRefCounted memAccountingRefCounted ,
402402 boolean submitToCB ,
403- IntBooleanFunction memoryUsageAccumulator
403+ MemoryUsageAccumulator memoryUsageAccumulator
404404 ) {
405405 return () -> {
406406 StoredFieldLoader rootLoader = profiler .storedFields (StoredFieldLoader .create (true , Collections .emptySet ()));
@@ -410,7 +410,7 @@ private static Supplier<Source> lazyStoredSourceLoader(
410410 BytesReference source = leafRootLoader .source ();
411411 int accumulatedInLeaf = memoryUsageAccumulator .apply (source .length (), submitToCB );
412412 if (submitToCB ) {
413- memAccountingRefCounted .setBytesAndAccount (accumulatedInLeaf , "lazy fetch phase source loader" );
413+ memAccountingRefCounted .account (accumulatedInLeaf , "lazy fetch phase source loader" );
414414 }
415415 return Source .fromBytes (source );
416416 } catch (IOException e ) {
@@ -515,7 +515,45 @@ public String toString() {
515515 }
516516
517517 @ FunctionalInterface
518- private interface IntBooleanFunction {
518+ private interface MemoryUsageAccumulator {
519519 int apply (int i , boolean b );
520520 }
521+
522+ /**
523+ * A ref counted object that accounts for memory usage in bytes and releases the
524+ * accounted memory from the circuit breaker when the reference count reaches zero.
525+ */
526+ static final class MemoryAccountingBytesRefCounted extends AbstractRefCounted {
527+
528+ // the bytes that we account for are not volatile because we only accumulate
529+ // in the single threaded fetch phase and we release the reference after
530+ // we write the response to the network (OutboundHandler). As with all other
531+ // SearchHit fields this will be visible to the network thread that'll call #decRef.
532+ private int bytes ;
533+ private final CircuitBreaker breaker ;
534+
535+ private MemoryAccountingBytesRefCounted (CircuitBreaker breaker ) {
536+ this .breaker = breaker ;
537+ }
538+
539+ public static MemoryAccountingBytesRefCounted create (CircuitBreaker breaker ) {
540+ return new MemoryAccountingBytesRefCounted (breaker );
541+ }
542+
543+ /**
544+ * This method increments the local counter for the accounted bytes and submits
545+ * the accumulated bytes to the circuit breaker.
546+ * This method is not thread-safe and should only be called from the single-threaded
547+ * fetch phase.
548+ */
549+ public void account (int bytes , String label ) {
550+ this .bytes += bytes ;
551+ breaker .addEstimateBytesAndMaybeBreak (bytes , label );
552+ }
553+
554+ @ Override
555+ protected void closeInternal () {
556+ breaker .addWithoutBreaking (-bytes );
557+ }
558+ }
521559}
0 commit comments