Skip to content

Commit 0631be5

Browse files
authored
Account for DelayedBucket before reduction (elastic#113013) (elastic#113458)
This commit moves the account for the DelayableBucket before reduction, therefore in some adversarial cases, we should exit much sooner.
1 parent f8dbda3 commit 0631be5

File tree

5 files changed

+70
-9
lines changed

5 files changed

+70
-9
lines changed

docs/changelog/113013.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 113013
2+
summary: Account for `DelayedBucket` before reduction
3+
area: Aggregations
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/search/aggregations/DelayedBucket.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
/**
1616
* A wrapper around reducing buckets with the same key that can delay that reduction
1717
* as long as possible. It's stateful and not even close to thread safe.
18+
* <p>
19+
* It is responsibility of the caller to account for buckets created using DelayedBucket.
20+
* It should call {@link #nonCompetitive} to release any possible sub-bucket creation if
21+
* a bucket is rejected from the final response.
1822
*/
1923
public final class DelayedBucket<B extends InternalMultiBucketAggregation.InternalBucket> {
2024
/**
@@ -45,7 +49,6 @@ public DelayedBucket(List<B> toReduce) {
4549
*/
4650
public B reduced(BiFunction<List<B>, AggregationReduceContext, B> reduce, AggregationReduceContext reduceContext) {
4751
if (reduced == null) {
48-
reduceContext.consumeBucketsAndMaybeBreak(1);
4952
reduced = reduce.apply(toReduce, reduceContext);
5053
toReduce = null;
5154
}
@@ -95,8 +98,8 @@ public String toString() {
9598
*/
9699
void nonCompetitive(AggregationReduceContext reduceContext) {
97100
if (reduced != null) {
98-
// -1 for itself, -countInnerBucket for all the sub-buckets.
99-
reduceContext.consumeBucketsAndMaybeBreak(-1 - InternalMultiBucketAggregation.countInnerBucket(reduced));
101+
// -countInnerBucket for all the sub-buckets.
102+
reduceContext.consumeBucketsAndMaybeBreak(-InternalMultiBucketAggregation.countInnerBucket(reduced));
100103
}
101104
}
102105
}

server/src/main/java/org/elasticsearch/search/aggregations/TopBucketBuilder.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,11 @@ public void add(DelayedBucket<B> bucket) {
132132
DelayedBucket<B> removed = queue.insertWithOverflow(bucket);
133133
if (removed != null) {
134134
nonCompetitive.accept(removed);
135+
// release any created sub-buckets
135136
removed.nonCompetitive(reduceContext);
137+
} else {
138+
// add one bucket to the final result
139+
reduceContext.consumeBucketsAndMaybeBreak(1);
136140
}
137141
}
138142

@@ -183,6 +187,8 @@ public void add(DelayedBucket<B> bucket) {
183187
next.add(bucket);
184188
return;
185189
}
190+
// add one bucket to the final result
191+
reduceContext.consumeBucketsAndMaybeBreak(1);
186192
buffer.add(bucket);
187193
if (buffer.size() < size) {
188194
return;

server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/AbstractInternalTerms.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ public InternalAggregation get() {
290290
result = new ArrayList<>();
291291
thisReduceOrder = reduceBuckets(bucketsList, getThisReduceOrder(), bucket -> {
292292
if (result.size() < getRequiredSize()) {
293+
reduceContext.consumeBucketsAndMaybeBreak(1);
293294
result.add(bucket.reduced(AbstractInternalTerms.this::reduceBucket, reduceContext));
294295
} else {
295296
otherDocCount[0] += bucket.getDocCount();
@@ -311,11 +312,10 @@ public InternalAggregation get() {
311312
result = top.build();
312313
} else {
313314
result = new ArrayList<>();
314-
thisReduceOrder = reduceBuckets(
315-
bucketsList,
316-
getThisReduceOrder(),
317-
bucket -> result.add(bucket.reduced(AbstractInternalTerms.this::reduceBucket, reduceContext))
318-
);
315+
thisReduceOrder = reduceBuckets(bucketsList, getThisReduceOrder(), bucket -> {
316+
reduceContext.consumeBucketsAndMaybeBreak(1);
317+
result.add(bucket.reduced(AbstractInternalTerms.this::reduceBucket, reduceContext));
318+
});
319319
}
320320
for (B r : result) {
321321
if (sumDocCountError == -1) {

server/src/test/java/org/elasticsearch/search/aggregations/DelayedBucketTests.java

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import static org.hamcrest.Matchers.greaterThan;
2626
import static org.hamcrest.Matchers.lessThan;
2727
import static org.hamcrest.Matchers.sameInstance;
28+
import static org.mockito.Mockito.mock;
29+
import static org.mockito.Mockito.when;
2830

2931
public class DelayedBucketTests extends ESTestCase {
3032
public void testToString() {
@@ -40,6 +42,23 @@ public void testReduced() {
4042
assertThat(b.reduced(reduce, context), sameInstance(b.reduced(reduce, context)));
4143
assertThat(b.reduced(reduce, context).getKeyAsString(), equalTo("test"));
4244
assertThat(b.reduced(reduce, context).getDocCount(), equalTo(3L));
45+
// it only accounts for sub-buckets
46+
assertEquals(0, buckets.get());
47+
}
48+
49+
public void testReducedSubAggregation() {
50+
AtomicInteger buckets = new AtomicInteger();
51+
AggregationReduceContext context = new AggregationReduceContext.ForFinal(null, null, () -> false, null, buckets::addAndGet);
52+
BiFunction<List<InternalBucket>, AggregationReduceContext, InternalBucket> reduce = mockReduce(context);
53+
DelayedBucket<InternalBucket> b = new DelayedBucket<>(
54+
List.of(bucket("test", 1, mockMultiBucketAgg()), bucket("test", 2, mockMultiBucketAgg()))
55+
);
56+
57+
assertThat(b.getDocCount(), equalTo(3L));
58+
assertThat(b.reduced(reduce, context), sameInstance(b.reduced(reduce, context)));
59+
assertThat(b.reduced(reduce, context).getKeyAsString(), equalTo("test"));
60+
assertThat(b.reduced(reduce, context).getDocCount(), equalTo(3L));
61+
// it only accounts for sub-buckets
4362
assertEquals(1, buckets.get());
4463
}
4564

@@ -76,6 +95,19 @@ public void testNonCompetitiveReduced() {
7695
BiFunction<List<InternalBucket>, AggregationReduceContext, InternalBucket> reduce = mockReduce(context);
7796
DelayedBucket<InternalBucket> b = new DelayedBucket<>(List.of(bucket("test", 1)));
7897
b.reduced(reduce, context);
98+
// only account for sub-aggregations
99+
assertEquals(0, buckets.get());
100+
b.nonCompetitive(context);
101+
assertEquals(0, buckets.get());
102+
}
103+
104+
public void testNonCompetitiveReducedSubAggregation() {
105+
AtomicInteger buckets = new AtomicInteger();
106+
AggregationReduceContext context = new AggregationReduceContext.ForFinal(null, null, () -> false, null, buckets::addAndGet);
107+
BiFunction<List<InternalBucket>, AggregationReduceContext, InternalBucket> reduce = mockReduce(context);
108+
DelayedBucket<InternalBucket> b = new DelayedBucket<>(List.of(bucket("test", 1, mockMultiBucketAgg())));
109+
b.reduced(reduce, context);
110+
// only account for sub-aggregations
79111
assertEquals(1, buckets.get());
80112
b.nonCompetitive(context);
81113
assertEquals(0, buckets.get());
@@ -85,10 +117,25 @@ private static InternalBucket bucket(String key, long docCount) {
85117
return new StringTerms.Bucket(new BytesRef(key), docCount, InternalAggregations.EMPTY, false, 0, DocValueFormat.RAW);
86118
}
87119

120+
private static InternalBucket bucket(String key, long docCount, InternalAggregations subAggregations) {
121+
return new StringTerms.Bucket(new BytesRef(key), docCount, subAggregations, false, 0, DocValueFormat.RAW);
122+
}
123+
88124
static BiFunction<List<InternalBucket>, AggregationReduceContext, InternalBucket> mockReduce(AggregationReduceContext context) {
89125
return (l, c) -> {
90126
assertThat(c, sameInstance(context));
91-
return bucket(l.get(0).getKeyAsString(), l.stream().mapToLong(Bucket::getDocCount).sum());
127+
context.consumeBucketsAndMaybeBreak(l.get(0).getAggregations().asList().size());
128+
return bucket(l.get(0).getKeyAsString(), l.stream().mapToLong(Bucket::getDocCount).sum(), l.get(0).getAggregations());
92129
};
93130
}
131+
132+
@SuppressWarnings("unchecked")
133+
private InternalAggregations mockMultiBucketAgg() {
134+
List<InternalBucket> buckets = List.of(bucket("sub", 1));
135+
InternalMultiBucketAggregation<?, InternalBucket> mock = (InternalMultiBucketAggregation<?, InternalBucket>) mock(
136+
InternalMultiBucketAggregation.class
137+
);
138+
when(mock.getBuckets()).thenReturn(buckets);
139+
return InternalAggregations.from(List.of(mock));
140+
}
94141
}

0 commit comments

Comments
 (0)