Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
347feb1
ESQL: Aggressive release of shard contexts
GalLalouche May 7, 2025
06de0d7
Update docs/changelog/129454.yaml
GalLalouche Jun 15, 2025
25fd514
Fix compilation errors
GalLalouche Jun 15, 2025
0210097
Fix bug caused by mishandling of errors during driver iteration
GalLalouche Jun 15, 2025
efc0dfd
Merge branch 'feature/shard_ref_count' of github.com:GalLalouche/elas…
GalLalouche Jun 15, 2025
f51a7c4
Change order of removal from first to last
GalLalouche Jun 16, 2025
d6ebed2
Remove printlns
GalLalouche Jun 16, 2025
8fb564a
Add ref counter to OrdinalsGroupingOperator
GalLalouche Jun 16, 2025
2843e20
Fix failing test (ManyShardsIT)
GalLalouche Jun 17, 2025
cfbf4b2
CR fixes
GalLalouche Jun 17, 2025
1ad8eb4
More test fixes
GalLalouche Jun 17, 2025
b0a35b1
Fix random DocVector generation (shard cannot be negative)
GalLalouche Jun 17, 2025
5b9b897
Merge branch 'main' into feature/shard_ref_count
GalLalouche Jun 17, 2025
73ddfe7
More edge cases for shard IDs in tests
GalLalouche Jun 17, 2025
e22cc53
Merge branch 'main' into feature/shard_ref_count
GalLalouche Jun 18, 2025
2c25218
Merge branch 'main' into feature/shard_ref_count
GalLalouche Jun 18, 2025
e621964
CR comments
GalLalouche Jun 23, 2025
5246f3a
CR fixes
GalLalouche Jun 24, 2025
2f219c8
Merge branch 'main' into feature/shard_ref_count
GalLalouche Jun 24, 2025
4761a2e
Merge branch 'main' into feature/shard_ref_count
GalLalouche Jun 24, 2025
ce8462b
Merge branch 'main' into feature/shard_ref_count
GalLalouche Jun 24, 2025
aab0407
Merge branch 'main' into feature/shard_ref_count
GalLalouche Jun 24, 2025
6270833
Move shardRefCounter logic to the super class
GalLalouche Jun 24, 2025
97af248
Merge branch 'main' into feature/shard_ref_count
GalLalouche Jun 24, 2025
de0d4bd
Fix double ref counting from previous PR
GalLalouche Jun 24, 2025
bb9a42b
Merge branch 'main' into feature/shard_ref_count
GalLalouche Jun 24, 2025
69acdfd
Merge branch 'main' into feature/shard_ref_count
GalLalouche Jun 24, 2025
c71ae9d
Merge branch 'main' into feature/shard_ref_count
GalLalouche Jun 25, 2025
c8af2c6
Merge branch 'feature/shard_ref_count' of github.com:GalLalouche/elas…
GalLalouche Jun 25, 2025
3d28f52
Merge branch 'main' into feature/shard_ref_count
GalLalouche Jun 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.lucene.LuceneSourceOperator;
import org.elasticsearch.compute.lucene.ShardRefCounted;
import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator;
import org.elasticsearch.compute.operator.topn.TopNOperator;
import org.elasticsearch.core.IOUtils;
Expand Down Expand Up @@ -477,6 +478,7 @@ private void setupPages() {
pages.add(
new Page(
new DocVector(
ShardRefCounted.ALWAYS_REFERENCED,
blockFactory.newConstantIntBlockWith(0, end - begin).asVector(),
blockFactory.newConstantIntBlockWith(ctx.ord, end - begin).asVector(),
docs.build(),
Expand Down Expand Up @@ -512,7 +514,14 @@ record ItrAndOrd(PrimitiveIterator.OfInt itr, int ord) {}
if (size >= BLOCK_LENGTH) {
pages.add(
new Page(
new DocVector(blockFactory.newConstantIntVector(0, size), leafs.build(), docs.build(), null).asBlock()
new DocVector(

ShardRefCounted.ALWAYS_REFERENCED,
blockFactory.newConstantIntVector(0, size),
leafs.build(),
docs.build(),
null
).asBlock()
)
);
docs = blockFactory.newIntVectorBuilder(BLOCK_LENGTH);
Expand All @@ -525,6 +534,8 @@ record ItrAndOrd(PrimitiveIterator.OfInt itr, int ord) {}
pages.add(
new Page(
new DocVector(

ShardRefCounted.ALWAYS_REFERENCED,
blockFactory.newConstantIntBlockWith(0, size).asVector(),
leafs.build().asBlock().asVector(),
docs.build(),
Expand All @@ -551,6 +562,8 @@ record ItrAndOrd(PrimitiveIterator.OfInt itr, int ord) {}
pages.add(
new Page(
new DocVector(

ShardRefCounted.ALWAYS_REFERENCED,
blockFactory.newConstantIntVector(0, 1),
blockFactory.newConstantIntVector(next.ord, 1),
blockFactory.newConstantIntVector(next.itr.nextInt(), 1),
Expand Down
5 changes: 5 additions & 0 deletions docs/changelog/129454.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 129454
summary: Aggressive release of shard contexts
area: ES|QL
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,11 @@ public boolean equals(Object obj) {
}
}

/**
 * Creates a {@link Releasable} that calls {@link RefCounted#decRef()} when closed.
 * The caller is responsible for having acquired the reference being released.
 */
public static Releasable fromRefCounted(RefCounted refCounted) {
    // Method reference is the idiomatic form of the decRef-forwarding lambda.
    return refCounted::decRef;
}

private static class ReleaseOnce extends AtomicReference<Releasable> implements Releasable {
ReleaseOnce(Releasable releasable) {
super(releasable);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ public final void close() {
closeFuture.onResponse(null);
}

/** Returns {@code true} once {@link #close()} has run, i.e. the close future has completed. */
public final boolean isClosed() {
    return closeFuture.isDone();
}

/**
* Should be called before executing the main query and after all other parameters have been set.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,12 @@ protected SearchContext createContext(
@Override
public SearchContext createSearchContext(ShardSearchRequest request, TimeValue timeout) throws IOException {
SearchContext searchContext = super.createSearchContext(request, timeout);
onPutContext.accept(searchContext.readerContext());
try {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was done after confirming with @dnhatn that onPutContext here can be replaced with onCreateSearchContext. The try/catch clause was copy pasted from above.

onCreateSearchContext.accept(searchContext);
} catch (Exception e) {
searchContext.close();
throw e;
}
searchContext.addReleasable(() -> onRemoveContext.accept(searchContext.readerContext()));
return searchContext;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.compute.lucene.ShardRefCounted;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;

Expand All @@ -17,7 +19,7 @@
/**
* Wrapper around {@link DocVector} to make a valid {@link Block}.
*/
public class DocBlock extends AbstractVectorBlock implements Block {
public class DocBlock extends AbstractVectorBlock implements Block, RefCounted {

private final DocVector vector;

Expand Down Expand Up @@ -96,6 +98,12 @@ public static class Builder implements Block.Builder {
private final IntVector.Builder shards;
private final IntVector.Builder segments;
private final IntVector.Builder docs;
private ShardRefCounted shardRefCounters = ShardRefCounted.ALWAYS_REFERENCED;

/**
 * Sets the {@link ShardRefCounted} that the built {@link DocVector} will use; defaults to
 * {@code ShardRefCounted.ALWAYS_REFERENCED}.
 *
 * @return this builder, for chaining
 */
public Builder setShardRefCounted(ShardRefCounted shardRefCounters) {
    this.shardRefCounters = shardRefCounters;
    return this;
}

private Builder(BlockFactory blockFactory, int estimatedSize) {
IntVector.Builder shards = null;
Expand Down Expand Up @@ -183,7 +191,7 @@ public DocBlock build() {
shards = this.shards.build();
segments = this.segments.build();
docs = this.docs.build();
result = new DocVector(shards, segments, docs, null);
result = new DocVector(shardRefCounters, shards, segments, docs, null);
return result.asBlock();
} finally {
if (result == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@
import org.apache.lucene.util.IntroSorter;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.compute.lucene.ShardRefCounted;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;

import java.util.BitSet;
import java.util.Objects;
import java.util.function.Consumer;

/**
* {@link Vector} where each entry references a lucene document.
Expand All @@ -29,6 +33,7 @@ public final class DocVector extends AbstractVector implements Vector {
public static final int SHARD_SEGMENT_DOC_MAP_PER_ROW_OVERHEAD = Integer.BYTES * 2;

private final IntVector shards;
private final IntVector uniqueShards;
private final IntVector segments;
private final IntVector docs;

Expand All @@ -48,8 +53,21 @@ public final class DocVector extends AbstractVector implements Vector {
*/
private int[] shardSegmentDocMapBackwards;

public DocVector(IntVector shards, IntVector segments, IntVector docs, Boolean singleSegmentNonDecreasing) {
// Per-shard ref counters; a ref is acquired per referenced shard in the constructor and
// released in closeInternal().
private final ShardRefCounted shardRefCounters;

/** Returns the per-shard ref counters backing this vector. */
public ShardRefCounted shardRefCounted() {
    return shardRefCounters;
}

public DocVector(
ShardRefCounted shardRefCounters,
IntVector shards,
IntVector segments,
IntVector docs,
Boolean singleSegmentNonDecreasing
) {
super(shards.getPositionCount(), shards.blockFactory());
this.shardRefCounters = shardRefCounters;
this.shards = shards;
this.segments = segments;
this.docs = docs;
Expand All @@ -64,11 +82,50 @@ public DocVector(IntVector shards, IntVector segments, IntVector docs, Boolean s
"invalid position count [" + shards.getPositionCount() + " != " + docs.getPositionCount() + "]"
);
}
blockFactory().adjustBreaker(BASE_RAM_BYTES_USED);
var uniqueShards = computeUniqueShards(shards);
try {
blockFactory().adjustBreaker(BASE_RAM_BYTES_USED);
this.uniqueShards = uniqueShards;
forEachShardRefCounter(RefCounted::mustIncRef);
} catch (Exception e) {
Releasables.close(uniqueShards);
throw e;
}
}

/**
 * Computes the distinct shard ids present in {@code shards}, in first-seen order.
 * Constant vectors are handled without scanning: a constant int vector has exactly one
 * shard id, and a constant-null vector yields an empty result.
 */
private static IntVector computeUniqueShards(IntVector shards) {
    if (shards instanceof ConstantIntVector constant) {
        return shards.blockFactory().newConstantIntVector(constant.getInt(0), 1);
    }
    if (shards instanceof ConstantNullVector) {
        return shards.blockFactory().newConstantIntVector(0, 0);
    }
    var seen = new BitSet(128);
    try (IntVector.Builder unique = shards.blockFactory().newIntVectorBuilder(shards.getPositionCount())) {
        for (int position = 0; position < shards.getPositionCount(); position++) {
            int shardId = shards.getInt(position);
            if (seen.get(shardId)) {
                continue;
            }
            seen.set(shardId);
            unique.appendInt(shardId);
        }
        return unique.build();
    }
}

public DocVector(IntVector shards, IntVector segments, IntVector docs, int[] docMapForwards, int[] docMapBackwards) {
this(shards, segments, docs, null);
/**
 * Creates a DocVector with precomputed forward/backward shard-segment-doc maps, delegating
 * to the main constructor and then installing the supplied maps.
 */
public DocVector(
    ShardRefCounted shardRefCounters,
    IntVector shards,
    IntVector segments,
    IntVector docs,
    int[] docMapForwards,
    int[] docMapBackwards
) {
    this(shardRefCounters, shards, segments, docs, null);
    this.shardSegmentDocMapForwards = docMapForwards;
    this.shardSegmentDocMapBackwards = docMapBackwards;
}
Expand Down Expand Up @@ -238,7 +295,7 @@ public DocVector filter(int... positions) {
filteredShards = shards.filter(positions);
filteredSegments = segments.filter(positions);
filteredDocs = docs.filter(positions);
result = new DocVector(filteredShards, filteredSegments, filteredDocs, null);
result = new DocVector(shardRefCounters, filteredShards, filteredSegments, filteredDocs, null);
return result;
} finally {
if (result == null) {
Expand Down Expand Up @@ -287,35 +344,54 @@ private static long ramBytesOrZero(int[] array) {

/** Estimates the heap usage of a DocVector built from these components, including the unique-shards vector. */
public static long ramBytesEstimated(
    IntVector shards,
    IntVector uniqueShards,
    IntVector segments,
    IntVector docs,
    int[] shardSegmentDocMapForwards,
    int[] shardSegmentDocMapBackwards
) {
    return BASE_RAM_BYTES_USED
        + RamUsageEstimator.sizeOf(shards)
        + RamUsageEstimator.sizeOf(uniqueShards)
        + RamUsageEstimator.sizeOf(segments)
        + RamUsageEstimator.sizeOf(docs)
        + ramBytesOrZero(shardSegmentDocMapForwards)
        + ramBytesOrZero(shardSegmentDocMapBackwards);
}

/** Reports this vector's estimated heap usage, including the unique-shards vector. */
@Override
public long ramBytesUsed() {
    return ramBytesEstimated(shards, uniqueShards, segments, docs, shardSegmentDocMapForwards, shardSegmentDocMapBackwards);
}

/** Marks this vector and all of its component vectors as safe to hand to a different driver. */
@Override
public void allowPassingToDifferentDriver() {
    super.allowPassingToDifferentDriver();
    shards.allowPassingToDifferentDriver();
    uniqueShards.allowPassingToDifferentDriver();
    segments.allowPassingToDifferentDriver();
    docs.allowPassingToDifferentDriver();
}

/**
 * Releases this vector's resources: first drops the shard refs acquired in the constructor,
 * then returns the circuit-breaker reservation and closes the component vectors.
 */
@Override
public void closeInternal() {
    // Release one ref per unique shard (mirrors mustIncRef in the constructor).
    forEachShardRefCounter(RefCounted::decRef);
    Releasables.closeExpectNoException(
        () -> blockFactory().adjustBreaker(-BASE_RAM_BYTES_USED - (shardSegmentDocMapForwards == null ? 0 : sizeOfSegmentDocMap())),
        shards,
        uniqueShards,
        segments,
        docs
    );
}

/**
 * Applies {@code consumer} to the {@link RefCounted} of every shard referenced by this vector.
 * Constant vectors are special-cased: a constant int vector references exactly one shard and a
 * constant-null vector references none; otherwise the precomputed unique-shards vector is used.
 */
private void forEachShardRefCounter(Consumer<RefCounted> consumer) {
    if (shards instanceof ConstantIntVector constant) {
        consumer.accept(shardRefCounters.get(constant.getInt(0)));
    } else if (shards instanceof ConstantNullVector) {
        // No shards referenced — nothing to do.
    } else {
        int count = uniqueShards.getPositionCount();
        for (int i = 0; i < count; i++) {
            consumer.accept(shardRefCounters.get(uniqueShards.getInt(i)));
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.SourceOperator;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.core.Releasables;

import java.io.IOException;
Expand All @@ -34,12 +35,14 @@ public class LuceneCountOperator extends LuceneOperator {

private static final int PAGE_SIZE = 1;

private final List<? extends RefCounted> shardRefCounters;
private int totalHits = 0;
private int remainingDocs;

private final LeafCollector leafCollector;

public static class Factory extends LuceneOperator.Factory {
private final List<? extends RefCounted> shardRefCounters;

public Factory(
List<? extends ShardContext> contexts,
Expand All @@ -58,11 +61,12 @@ public Factory(
false,
ScoreMode.COMPLETE_NO_SCORES
);
this.shardRefCounters = contexts;
}

/** Builds a count operator that holds a ref on every shard context for its lifetime. */
@Override
public SourceOperator get(DriverContext driverContext) {
    return new LuceneCountOperator(shardRefCounters, driverContext.blockFactory(), sliceQueue, limit);
}

@Override
Expand All @@ -71,8 +75,15 @@ public String describe() {
}
}

public LuceneCountOperator(BlockFactory blockFactory, LuceneSliceQueue sliceQueue, int limit) {
public LuceneCountOperator(
List<? extends RefCounted> shardRefCounters,
BlockFactory blockFactory,
LuceneSliceQueue sliceQueue,
int limit
) {
super(blockFactory, PAGE_SIZE, sliceQueue);
this.shardRefCounters = shardRefCounters;
shardRefCounters.forEach(RefCounted::mustIncRef);
this.remainingDocs = limit;
this.leafCollector = new LeafCollector() {
@Override
Expand Down Expand Up @@ -171,4 +182,9 @@ protected Page getCheckedOutput() throws IOException {
/** Appends operator-specific status (the remaining document budget) to the description. */
protected void describe(StringBuilder sb) {
    sb.append(", remainingDocs=").append(remainingDocs);
}

/** Releases the shard references acquired in the constructor. */
@Override
public void close() {
    shardRefCounters.forEach(RefCounted::decRef);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,11 @@ public final Page getOutput() {
protected abstract Page getCheckedOutput() throws IOException;

/**
 * Releases the ref held on the current scorer's shard context, if a scorer is still held.
 * The ref was acquired in the {@code LuceneScorer} constructor.
 */
@Override
public void close() {
    if (currentScorer != null) {
        currentScorer.shardContext().decRef();
    }
}

LuceneScorer getCurrentOrLoadNextScorer() {
while (currentScorer == null || currentScorer.isDone()) {
Expand All @@ -161,7 +165,11 @@ LuceneScorer getCurrentOrLoadNextScorer() {
) {
final Weight weight = currentSlice.weight();
processedQueries.add(weight.getQuery());
var previousScorer = currentScorer;
currentScorer = new LuceneScorer(currentSlice.shardContext(), weight, currentSlice.tags(), leaf);
if (previousScorer != null) {
previousScorer.shardContext().decRef();
}
}
assert currentScorer.maxPosition <= partialLeaf.maxDoc() : currentScorer.maxPosition + ">" + partialLeaf.maxDoc();
currentScorer.maxPosition = partialLeaf.maxDoc();
Expand All @@ -188,6 +196,7 @@ static final class LuceneScorer {
private Thread executingThread;

LuceneScorer(ShardContext shardContext, Weight weight, List<Object> tags, LeafReaderContext leafReaderContext) {
shardContext.incRef();
this.shardContext = shardContext;
this.weight = weight;
this.tags = tags;
Expand Down
Loading
Loading