Skip to content

Commit 492fb60

Browse files
committed
CR fixes
1 parent 2843e20 commit 492fb60

File tree

5 files changed

+51
-17
lines changed

5 files changed

+51
-17
lines changed

libs/core/src/main/java/org/elasticsearch/core/Releasables.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ public boolean equals(Object obj) {
202202
}
203203
}
204204

205+
/** Creates a {@link Releasable} that calls {@link RefCounted#decRef()} when closed. */
205206
public static Releasable fromRefCounted(RefCounted refCounted) {
206207
return () -> refCounted.decRef();
207208
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocVector.java

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111
import org.apache.lucene.util.RamUsageEstimator;
1212
import org.elasticsearch.common.unit.ByteSizeValue;
1313
import org.elasticsearch.compute.lucene.ShardRefCounted;
14+
import org.elasticsearch.core.RefCounted;
1415
import org.elasticsearch.core.ReleasableIterator;
1516
import org.elasticsearch.core.Releasables;
1617

18+
import java.util.BitSet;
1719
import java.util.Objects;
20+
import java.util.function.Consumer;
1821

1922
/**
2023
* {@link Vector} where each entry references a lucene document.
@@ -30,6 +33,7 @@ public final class DocVector extends AbstractVector implements Vector {
3033
public static final int SHARD_SEGMENT_DOC_MAP_PER_ROW_OVERHEAD = Integer.BYTES * 2;
3134

3235
private final IntVector shards;
36+
private final IntVector uniqueShards;
3337
private final IntVector segments;
3438
private final IntVector docs;
3539

@@ -65,6 +69,7 @@ public DocVector(
6569
super(shards.getPositionCount(), shards.blockFactory());
6670
this.shardRefCounters = shardRefCounters;
6771
this.shards = shards;
72+
this.uniqueShards = computeUniqueShards(shards);
6873
this.segments = segments;
6974
this.docs = docs;
7075
this.singleSegmentNonDecreasing = singleSegmentNonDecreasing;
@@ -80,7 +85,31 @@ public DocVector(
8085
}
8186
blockFactory().adjustBreaker(BASE_RAM_BYTES_USED);
8287

83-
forEachShardRefCounter(DecOrInc.INC);
88+
forEachShardRefCounter(RefCounted::mustIncRef);
89+
}
90+
91+
private static IntVector computeUniqueShards(IntVector shards) {
92+
switch (shards) {
93+
case ConstantIntVector constantIntVector -> {
94+
return shards.blockFactory().newConstantIntVector(constantIntVector.getInt(0), 1);
95+
}
96+
case ConstantNullVector unused -> {
97+
return shards.blockFactory().newConstantIntVector(0, 0);
98+
}
99+
default -> {
100+
var seen = new BitSet(128);
101+
try (IntVector.Builder uniqueShardsBuilder = shards.blockFactory().newIntVectorBuilder(shards.getPositionCount())) {
102+
for (int p = 0; p < shards.getPositionCount(); p++) {
103+
int shardId = shards.getInt(p);
104+
if (seen.get(shardId) == false) {
105+
seen.set(shardId);
106+
uniqueShardsBuilder.appendInt(shardId);
107+
}
108+
}
109+
return uniqueShardsBuilder.build();
110+
}
111+
}
112+
}
84113
}
85114

86115
public DocVector(
@@ -337,33 +366,26 @@ public void closeInternal() {
337366
Releasables.closeExpectNoException(
338367
() -> blockFactory().adjustBreaker(-BASE_RAM_BYTES_USED - (shardSegmentDocMapForwards == null ? 0 : sizeOfSegmentDocMap())),
339368
shards,
369+
uniqueShards,
340370
segments,
341371
docs
342372
);
343-
forEachShardRefCounter(DecOrInc.DEC);
373+
forEachShardRefCounter(RefCounted::decRef);
344374
}
345375

346-
private enum DecOrInc {
347-
DEC,
348-
INC;
349-
350-
void apply(ShardRefCounted counters, int shardId) {
351-
switch (this) {
352-
case DEC -> counters.get(shardId).decRef();
353-
case INC -> counters.get(shardId).mustIncRef();
354-
}
355-
}
376+
private interface UpdateShardRefCounted {
377+
void apply(RefCounted rc);
356378
}
357379

358-
private void forEachShardRefCounter(DecOrInc mode) {
380+
private void forEachShardRefCounter(Consumer<RefCounted> consumer) {
359381
switch (shards) {
360-
case ConstantIntVector constantIntVector -> mode.apply(shardRefCounters, constantIntVector.getInt(0));
382+
case ConstantIntVector constantIntVector -> consumer.accept(shardRefCounters.get(constantIntVector.getInt(0)));
361383
case ConstantNullVector ignored -> {
362384
// Noop
363385
}
364386
default -> {
365-
for (int i = 0; i < shards.getPositionCount(); i++) {
366-
mode.apply(shardRefCounters, shards.getInt(i));
387+
for (int i = 0; i < uniqueShards.getPositionCount(); i++) {
388+
consumer.accept(shardRefCounters.get(uniqueShards.getInt(i)));
367389
}
368390
}
369391
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ShardRefCounted.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313

1414
/** Manages reference counting for {@link ShardContext}. */
1515
public interface ShardRefCounted {
16+
/**
17+
* @param shardId The shard index used by {@link org.elasticsearch.compute.data.DocVector}.
18+
* @return the {@link RefCounted} for the given shard. In production, this will almost always be a {@link ShardContext}.
19+
*/
1620
RefCounted get(int shardId);
1721

1822
static ShardRefCounted fromList(List<? extends RefCounted> refCounters) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ public class EsPhysicalOperationProviders extends AbstractPhysicalOperationProvi
9292
private static final Logger logger = LogManager.getLogger(EsPhysicalOperationProviders.class);
9393

9494
/**
95-
* Context of each shard we're operating against.
95+
* Context of each shard we're operating against. Note these objects are shared across multiple operators as
96+
* {@link org.elasticsearch.core.RefCounted}.
9697
*/
9798
public abstract static class ShardContext implements org.elasticsearch.compute.lucene.ShardContext, Releasable {
9899
private final AbstractRefCounted refCounted = new AbstractRefCounted() {
@@ -379,6 +380,10 @@ public Operator.OperatorFactory timeSeriesAggregatorOperatorFactory(
379380

380381
public static class DefaultShardContext extends ShardContext {
381382
private final int index;
383+
/**
384+
* In production, this will be a {@link org.elasticsearch.search.internal.SearchContext}, but we don't want to drag that huge
385+
* dependency here.
386+
*/
382387
private final Releasable releasable;
383388
private final SearchExecutionContext ctx;
384389
private final AliasFilter aliasFilter;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,8 @@ public SourceProvider createSourceProvider() {
590590
LOGGER.debug("Local execution plan:\n{}", localExecutionPlan.describe());
591591
}
592592
drivers = localExecutionPlan.createDrivers(context.sessionId());
593+
// After creating the drivers (and therefore, the operators), we can safely decrement the reference count since the operators
594+
// will hold a reference to the contexts where relevant.
593595
contexts.forEach(RefCounted::decRef);
594596
if (drivers.isEmpty()) {
595597
throw new IllegalStateException("no drivers created");

0 commit comments

Comments
 (0)