1111import org .apache .lucene .util .RamUsageEstimator ;
1212import org .elasticsearch .common .unit .ByteSizeValue ;
1313import org .elasticsearch .compute .lucene .ShardRefCounted ;
14+ import org .elasticsearch .core .RefCounted ;
1415import org .elasticsearch .core .ReleasableIterator ;
1516import org .elasticsearch .core .Releasables ;
1617
18+ import java .util .BitSet ;
1719import java .util .Objects ;
20+ import java .util .function .Consumer ;
1821
1922/**
2023 * {@link Vector} where each entry references a lucene document.
@@ -30,6 +33,7 @@ public final class DocVector extends AbstractVector implements Vector {
3033 public static final int SHARD_SEGMENT_DOC_MAP_PER_ROW_OVERHEAD = Integer .BYTES * 2 ;
3134
3235 private final IntVector shards ;
36+ private final IntVector uniqueShards ;
3337 private final IntVector segments ;
3438 private final IntVector docs ;
3539
@@ -65,6 +69,7 @@ public DocVector(
6569 super (shards .getPositionCount (), shards .blockFactory ());
6670 this .shardRefCounters = shardRefCounters ;
6771 this .shards = shards ;
72+ this .uniqueShards = computeUniqueShards (shards );
6873 this .segments = segments ;
6974 this .docs = docs ;
7075 this .singleSegmentNonDecreasing = singleSegmentNonDecreasing ;
@@ -80,7 +85,31 @@ public DocVector(
8085 }
8186 blockFactory ().adjustBreaker (BASE_RAM_BYTES_USED );
8287
83- forEachShardRefCounter (DecOrInc .INC );
88+ forEachShardRefCounter (RefCounted ::mustIncRef );
89+ }
90+
91+ private static IntVector computeUniqueShards (IntVector shards ) {
92+ switch (shards ) {
93+ case ConstantIntVector constantIntVector -> {
94+ return shards .blockFactory ().newConstantIntVector (constantIntVector .getInt (0 ), 1 );
95+ }
96+ case ConstantNullVector unused -> {
97+ return shards .blockFactory ().newConstantIntVector (0 , 0 );
98+ }
99+ default -> {
100+ var seen = new BitSet (128 );
101+ try (IntVector .Builder uniqueShardsBuilder = shards .blockFactory ().newIntVectorBuilder (shards .getPositionCount ())) {
102+ for (int p = 0 ; p < shards .getPositionCount (); p ++) {
103+ int shardId = shards .getInt (p );
104+ if (seen .get (shardId ) == false ) {
105+ seen .set (shardId );
106+ uniqueShardsBuilder .appendInt (shardId );
107+ }
108+ }
109+ return uniqueShardsBuilder .build ();
110+ }
111+ }
112+ }
84113 }
85114
86115 public DocVector (
@@ -337,33 +366,26 @@ public void closeInternal() {
337366 Releasables .closeExpectNoException (
338367 () -> blockFactory ().adjustBreaker (-BASE_RAM_BYTES_USED - (shardSegmentDocMapForwards == null ? 0 : sizeOfSegmentDocMap ())),
339368 shards ,
369+ uniqueShards ,
340370 segments ,
341371 docs
342372 );
343- forEachShardRefCounter (DecOrInc . DEC );
373+ forEachShardRefCounter (RefCounted :: decRef );
344374 }
345375
346- private enum DecOrInc {
347- DEC ,
348- INC ;
349-
350- void apply (ShardRefCounted counters , int shardId ) {
351- switch (this ) {
352- case DEC -> counters .get (shardId ).decRef ();
353- case INC -> counters .get (shardId ).mustIncRef ();
354- }
355- }
376+ private interface UpdateShardRefCounted {
377+ void apply (RefCounted rc );
356378 }
357379
358- private void forEachShardRefCounter (DecOrInc mode ) {
380+ private void forEachShardRefCounter (Consumer < RefCounted > consumer ) {
359381 switch (shards ) {
360- case ConstantIntVector constantIntVector -> mode . apply (shardRefCounters , constantIntVector .getInt (0 ));
382+ case ConstantIntVector constantIntVector -> consumer . accept (shardRefCounters . get ( constantIntVector .getInt (0 ) ));
361383 case ConstantNullVector ignored -> {
362384 // Noop
363385 }
364386 default -> {
365- for (int i = 0 ; i < shards .getPositionCount (); i ++) {
366- mode . apply (shardRefCounters , shards . getInt (i ));
387+ for (int i = 0 ; i < uniqueShards .getPositionCount (); i ++) {
388+ consumer . accept (shardRefCounters . get ( uniqueShards . getInt (i ) ));
367389 }
368390 }
369391 }
0 commit comments