Skip to content

Commit 0591cf6

Browse files
committed
CR fixes
1 parent 2843e20 commit 0591cf6

File tree

5 files changed

+49
-19
lines changed

5 files changed

+49
-19
lines changed

libs/core/src/main/java/org/elasticsearch/core/Releasables.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ public boolean equals(Object obj) {
202202
}
203203
}
204204

205+
/** Creates a {@link Releasable} that calls {@link RefCounted#decRef()} when closed. */
205206
public static Releasable fromRefCounted(RefCounted refCounted) {
206207
return () -> refCounted.decRef();
207208
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocVector.java

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111
import org.apache.lucene.util.RamUsageEstimator;
1212
import org.elasticsearch.common.unit.ByteSizeValue;
1313
import org.elasticsearch.compute.lucene.ShardRefCounted;
14+
import org.elasticsearch.core.RefCounted;
1415
import org.elasticsearch.core.ReleasableIterator;
1516
import org.elasticsearch.core.Releasables;
1617

18+
import java.util.BitSet;
1719
import java.util.Objects;
20+
import java.util.function.Consumer;
1821

1922
/**
2023
* {@link Vector} where each entry references a lucene document.
@@ -30,6 +33,7 @@ public final class DocVector extends AbstractVector implements Vector {
3033
public static final int SHARD_SEGMENT_DOC_MAP_PER_ROW_OVERHEAD = Integer.BYTES * 2;
3134

3235
private final IntVector shards;
36+
private final IntVector uniqueShards;
3337
private final IntVector segments;
3438
private final IntVector docs;
3539

@@ -80,7 +84,32 @@ public DocVector(
8084
}
8185
blockFactory().adjustBreaker(BASE_RAM_BYTES_USED);
8286

83-
forEachShardRefCounter(DecOrInc.INC);
87+
forEachShardRefCounter(RefCounted::mustIncRef);
88+
this.uniqueShards = computeUniqueShards(shards);
89+
}
90+
91+
private static IntVector computeUniqueShards(IntVector shards) {
92+
switch (shards) {
93+
case ConstantIntVector constantIntVector -> {
94+
return shards.blockFactory().newConstantIntVector(constantIntVector.getInt(0), 1);
95+
}
96+
case ConstantNullVector unused -> {
97+
return shards.blockFactory().newConstantIntVector(0, 0);
98+
}
99+
default -> {
100+
var seen = new BitSet(128);
101+
try (IntVector.Builder uniqueShardsBuilder = shards.blockFactory().newIntVectorBuilder(shards.getPositionCount())) {
102+
for (int p = 0; p < shards.getPositionCount(); p++) {
103+
int shardId = shards.getInt(p);
104+
if (seen.get(shardId) == false) {
105+
seen.set(shardId);
106+
uniqueShardsBuilder.appendInt(shardId);
107+
}
108+
}
109+
return uniqueShardsBuilder.build();
110+
}
111+
}
112+
}
84113
}
85114

86115
public DocVector(
@@ -337,33 +366,22 @@ public void closeInternal() {
337366
Releasables.closeExpectNoException(
338367
() -> blockFactory().adjustBreaker(-BASE_RAM_BYTES_USED - (shardSegmentDocMapForwards == null ? 0 : sizeOfSegmentDocMap())),
339368
shards,
369+
uniqueShards,
340370
segments,
341371
docs
342372
);
343-
forEachShardRefCounter(DecOrInc.DEC);
344-
}
345-
346-
private enum DecOrInc {
347-
DEC,
348-
INC;
349-
350-
void apply(ShardRefCounted counters, int shardId) {
351-
switch (this) {
352-
case DEC -> counters.get(shardId).decRef();
353-
case INC -> counters.get(shardId).mustIncRef();
354-
}
355-
}
373+
forEachShardRefCounter(RefCounted::decRef);
356374
}
357375

358-
private void forEachShardRefCounter(DecOrInc mode) {
376+
private void forEachShardRefCounter(Consumer<RefCounted> consumer) {
359377
switch (shards) {
360-
case ConstantIntVector constantIntVector -> mode.apply(shardRefCounters, constantIntVector.getInt(0));
378+
case ConstantIntVector constantIntVector -> consumer.accept(shardRefCounters.get(constantIntVector.getInt(0)));
361379
case ConstantNullVector ignored -> {
362380
// Noop
363381
}
364382
default -> {
365-
for (int i = 0; i < shards.getPositionCount(); i++) {
366-
mode.apply(shardRefCounters, shards.getInt(i));
383+
for (int i = 0; i < uniqueShards.getPositionCount(); i++) {
384+
consumer.accept(shardRefCounters.get(uniqueShards.getInt(i)));
367385
}
368386
}
369387
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ShardRefCounted.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313

1414
/** Manages reference counting for {@link ShardContext}. */
1515
public interface ShardRefCounted {
16+
/**
17+
* @param shardId The shard index used by {@link org.elasticsearch.compute.data.DocVector}.
18+
* @return the {@link RefCounted} for the given shard. In production, this will almost always be a {@link ShardContext}.
19+
*/
1620
RefCounted get(int shardId);
1721

1822
static ShardRefCounted fromList(List<? extends RefCounted> refCounters) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ public class EsPhysicalOperationProviders extends AbstractPhysicalOperationProvi
9292
private static final Logger logger = LogManager.getLogger(EsPhysicalOperationProviders.class);
9393

9494
/**
95-
* Context of each shard we're operating against.
95+
* Context of each shard we're operating against. Note these objects are shared across multiple operators as
96+
* {@link org.elasticsearch.core.RefCounted}.
9697
*/
9798
public abstract static class ShardContext implements org.elasticsearch.compute.lucene.ShardContext, Releasable {
9899
private final AbstractRefCounted refCounted = new AbstractRefCounted() {
@@ -379,6 +380,10 @@ public Operator.OperatorFactory timeSeriesAggregatorOperatorFactory(
379380

380381
public static class DefaultShardContext extends ShardContext {
381382
private final int index;
383+
/**
384+
* In production, this will be a {@link org.elasticsearch.search.internal.SearchContext}, but we don't want to drag that huge
385+
* dependency here.
386+
*/
382387
private final Releasable releasable;
383388
private final SearchExecutionContext ctx;
384389
private final AliasFilter aliasFilter;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,8 @@ public SourceProvider createSourceProvider() {
590590
LOGGER.debug("Local execution plan:\n{}", localExecutionPlan.describe());
591591
}
592592
drivers = localExecutionPlan.createDrivers(context.sessionId());
593+
// After creating the drivers (and therefore, the operators), we can safely decrement the reference count since the operators
594+
// will hold a reference to the contexts where relevant.
593595
contexts.forEach(RefCounted::decRef);
594596
if (drivers.isEmpty()) {
595597
throw new IllegalStateException("no drivers created");

0 commit comments

Comments
 (0)