1111import org .apache .lucene .util .RamUsageEstimator ;
1212import org .elasticsearch .common .unit .ByteSizeValue ;
1313import org .elasticsearch .compute .lucene .ShardRefCounted ;
14+ import org .elasticsearch .core .RefCounted ;
1415import org .elasticsearch .core .ReleasableIterator ;
1516import org .elasticsearch .core .Releasables ;
1617
18+ import java .util .BitSet ;
1719import java .util .Objects ;
20+ import java .util .function .Consumer ;
1821
1922/**
2023 * {@link Vector} where each entry references a lucene document.
@@ -30,6 +33,7 @@ public final class DocVector extends AbstractVector implements Vector {
3033 public static final int SHARD_SEGMENT_DOC_MAP_PER_ROW_OVERHEAD = Integer .BYTES * 2 ;
3134
3235 private final IntVector shards ;
36+ private final IntVector uniqueShards ;
3337 private final IntVector segments ;
3438 private final IntVector docs ;
3539
@@ -80,7 +84,32 @@ public DocVector(
8084 }
8185 blockFactory ().adjustBreaker (BASE_RAM_BYTES_USED );
8286
83- forEachShardRefCounter (DecOrInc .INC );
87+ forEachShardRefCounter (RefCounted ::mustIncRef );
88+ this .uniqueShards = computeUniqueShards (shards );
89+ }
90+
91+ private static IntVector computeUniqueShards (IntVector shards ) {
92+ switch (shards ) {
93+ case ConstantIntVector constantIntVector -> {
94+ return shards .blockFactory ().newConstantIntVector (constantIntVector .getInt (0 ), 1 );
95+ }
96+ case ConstantNullVector unused -> {
97+ return shards .blockFactory ().newConstantIntVector (0 , 0 );
98+ }
99+ default -> {
100+ var seen = new BitSet (128 );
101+ try (IntVector .Builder uniqueShardsBuilder = shards .blockFactory ().newIntVectorBuilder (shards .getPositionCount ())) {
102+ for (int p = 0 ; p < shards .getPositionCount (); p ++) {
103+ int shardId = shards .getInt (p );
104+ if (seen .get (shardId ) == false ) {
105+ seen .set (shardId );
106+ uniqueShardsBuilder .appendInt (shardId );
107+ }
108+ }
109+ return uniqueShardsBuilder .build ();
110+ }
111+ }
112+ }
84113 }
85114
86115 public DocVector (
@@ -310,18 +339,20 @@ private static long ramBytesOrZero(int[] array) {
310339
311340 public static long ramBytesEstimated (
312341 IntVector shards ,
342+ IntVector uniqueShards ,
313343 IntVector segments ,
314344 IntVector docs ,
315345 int [] shardSegmentDocMapForwards ,
316346 int [] shardSegmentDocMapBackwards
317347 ) {
318- return BASE_RAM_BYTES_USED + RamUsageEstimator .sizeOf (shards ) + RamUsageEstimator .sizeOf (segments ) + RamUsageEstimator .sizeOf (docs )
319- + ramBytesOrZero (shardSegmentDocMapForwards ) + ramBytesOrZero (shardSegmentDocMapBackwards );
348+ return BASE_RAM_BYTES_USED + RamUsageEstimator .sizeOf (shards ) + RamUsageEstimator .sizeOf (uniqueShards ) + RamUsageEstimator .sizeOf (
349+ segments
350+ ) + RamUsageEstimator .sizeOf (docs ) + ramBytesOrZero (shardSegmentDocMapForwards ) + ramBytesOrZero (shardSegmentDocMapBackwards );
320351 }
321352
322353 @ Override
323354 public long ramBytesUsed () {
324- return ramBytesEstimated (shards , segments , docs , shardSegmentDocMapForwards , shardSegmentDocMapBackwards );
355+ return ramBytesEstimated (shards , uniqueShards , segments , docs , shardSegmentDocMapForwards , shardSegmentDocMapBackwards );
325356 }
326357
327358 @ Override
@@ -337,33 +368,22 @@ public void closeInternal() {
337368 Releasables .closeExpectNoException (
338369 () -> blockFactory ().adjustBreaker (-BASE_RAM_BYTES_USED - (shardSegmentDocMapForwards == null ? 0 : sizeOfSegmentDocMap ())),
339370 shards ,
371+ uniqueShards ,
340372 segments ,
341373 docs
342374 );
343- forEachShardRefCounter (DecOrInc .DEC );
344- }
345-
346- private enum DecOrInc {
347- DEC ,
348- INC ;
349-
350- void apply (ShardRefCounted counters , int shardId ) {
351- switch (this ) {
352- case DEC -> counters .get (shardId ).decRef ();
353- case INC -> counters .get (shardId ).mustIncRef ();
354- }
355- }
375+ forEachShardRefCounter (RefCounted ::decRef );
356376 }
357377
358- private void forEachShardRefCounter (DecOrInc mode ) {
378+ private void forEachShardRefCounter (Consumer < RefCounted > consumer ) {
359379 switch (shards ) {
360- case ConstantIntVector constantIntVector -> mode . apply (shardRefCounters , constantIntVector .getInt (0 ));
380+ case ConstantIntVector constantIntVector -> consumer . accept (shardRefCounters . get ( constantIntVector .getInt (0 ) ));
361381 case ConstantNullVector ignored -> {
362382 // Noop
363383 }
364384 default -> {
365- for (int i = 0 ; i < shards .getPositionCount (); i ++) {
366- mode . apply (shardRefCounters , shards . getInt (i ));
385+ for (int i = 0 ; i < uniqueShards .getPositionCount (); i ++) {
386+ consumer . accept (shardRefCounters . get ( uniqueShards . getInt (i ) ));
367387 }
368388 }
369389 }
0 commit comments