|
13 | 13 | import org.apache.lucene.search.TopDocs; |
14 | 14 | import org.apache.lucene.search.TotalHits; |
15 | 15 | import org.elasticsearch.common.breaker.CircuitBreaker; |
| 16 | +import org.elasticsearch.common.breaker.CircuitBreakingException; |
16 | 17 | import org.elasticsearch.common.breaker.NoopCircuitBreaker; |
| 18 | +import org.elasticsearch.common.io.stream.DelayableWriteable; |
| 19 | +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; |
| 20 | +import org.elasticsearch.common.io.stream.StreamOutput; |
| 21 | +import org.elasticsearch.common.io.stream.Writeable; |
17 | 22 | import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; |
| 23 | +import org.elasticsearch.common.unit.ByteSizeValue; |
18 | 24 | import org.elasticsearch.common.util.BigArrays; |
19 | 25 | import org.elasticsearch.common.util.concurrent.EsExecutors; |
20 | 26 | import org.elasticsearch.common.util.concurrent.EsExecutors.TaskTrackingConfig; |
|
25 | 31 | import org.elasticsearch.search.aggregations.AggregationBuilder; |
26 | 32 | import org.elasticsearch.search.aggregations.AggregationReduceContext; |
27 | 33 | import org.elasticsearch.search.aggregations.InternalAggregations; |
| 34 | +import org.elasticsearch.search.aggregations.metrics.SumAggregationBuilder; |
28 | 35 | import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; |
| 36 | +import org.elasticsearch.search.builder.SearchSourceBuilder; |
29 | 37 | import org.elasticsearch.search.query.QuerySearchResult; |
30 | 38 | import org.elasticsearch.test.ESTestCase; |
31 | 39 | import org.elasticsearch.threadpool.TestThreadPool; |
|
40 | 48 | import java.util.concurrent.TimeUnit; |
41 | 49 | import java.util.concurrent.atomic.AtomicInteger; |
42 | 50 | import java.util.concurrent.atomic.AtomicReference; |
| 51 | +import java.util.function.Supplier; |
43 | 52 |
|
| 53 | +import static org.hamcrest.Matchers.equalTo; |
| 54 | +import static org.hamcrest.Matchers.greaterThanOrEqualTo; |
44 | 55 | import static org.mockito.Mockito.mock; |
45 | 56 |
|
46 | 57 | public class QueryPhaseResultConsumerTests extends ESTestCase { |
@@ -148,6 +159,122 @@ public void testProgressListenerExceptionsAreCaught() throws Exception { |
148 | 159 | } |
149 | 160 | } |
150 | 161 |
|
| 162 | + public void testBatchedEstimateSizeTooBig() throws Exception { |
| 163 | + SearchRequest searchRequest = new SearchRequest("index"); |
| 164 | + searchRequest.source(new SearchSourceBuilder().aggregation(new SumAggregationBuilder("sum"))); |
| 165 | + |
| 166 | + var circuitBreakerLimit = ByteSizeValue.ofMb(256); |
| 167 | + var circuitBreaker = newLimitedBreaker(circuitBreakerLimit); |
| 168 | + // More than what the CircuitBreaker should allow |
| 169 | + long aggregationEstimatedSize = (long) (circuitBreakerLimit.getBytes() * 1.1); |
| 170 | + |
| 171 | + try ( |
| 172 | + QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( |
| 173 | + searchRequest, |
| 174 | + executor, |
| 175 | + circuitBreaker, |
| 176 | + searchPhaseController, |
| 177 | + () -> false, |
| 178 | + new SearchProgressListener() { |
| 179 | + }, |
| 180 | + 10, |
| 181 | + e -> {} |
| 182 | + ) |
| 183 | + ) { |
| 184 | + var mergeResult = new QueryPhaseResultConsumer.MergeResult(List.of(), null, new DelegatingDelayableWriteable<>(() -> { |
| 185 | + fail("This shouldn't be called"); |
| 186 | + return null; |
| 187 | + }), aggregationEstimatedSize); |
| 188 | + queryPhaseResultConsumer.addBatchedPartialResult(new SearchPhaseController.TopDocsStats(0), mergeResult); |
| 189 | + |
| 190 | + try { |
| 191 | + queryPhaseResultConsumer.reduce(); |
| 192 | + fail("Expecting a circuit breaking exception to be thrown"); |
| 193 | + } catch (CircuitBreakingException e) { |
| 194 | + assertThat(e.getBytesWanted(), equalTo(aggregationEstimatedSize)); |
| 195 | + } |
| 196 | + } |
| 197 | + } |
| 198 | + |
| 199 | + public void testBatchedEstimateSizeTooBigAfterDeserialization() throws Exception { |
| 200 | + SearchRequest searchRequest = new SearchRequest("index"); |
| 201 | + searchRequest.source(new SearchSourceBuilder().aggregation(new SumAggregationBuilder("sum"))); |
| 202 | + |
| 203 | + var circuitBreakerLimit = ByteSizeValue.ofMb(256); |
| 204 | + var circuitBreaker = newLimitedBreaker(circuitBreakerLimit); |
| 205 | + // Less than the CB, but more after the 1.5x |
| 206 | + long aggregationEstimatedSize = (long) (circuitBreakerLimit.getBytes() * 0.75); |
| 207 | + |
| 208 | + try ( |
| 209 | + QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( |
| 210 | + searchRequest, |
| 211 | + executor, |
| 212 | + circuitBreaker, |
| 213 | + searchPhaseController, |
| 214 | + () -> false, |
| 215 | + new SearchProgressListener() { |
| 216 | + }, |
| 217 | + 10, |
| 218 | + e -> {} |
| 219 | + ) |
| 220 | + ) { |
| 221 | + var mergeResult = new QueryPhaseResultConsumer.MergeResult(List.of(), null, new DelegatingDelayableWriteable<>(() -> { |
| 222 | + fail("This shouldn't be called"); |
| 223 | + return null; |
| 224 | + }), aggregationEstimatedSize); |
| 225 | + queryPhaseResultConsumer.addBatchedPartialResult(new SearchPhaseController.TopDocsStats(0), mergeResult); |
| 226 | + |
| 227 | + try { |
| 228 | + queryPhaseResultConsumer.reduce(); |
| 229 | + fail("Expecting a circuit breaking exception to be thrown"); |
| 230 | + } catch (CircuitBreakingException e) { |
| 231 | + assertThat(circuitBreaker.getUsed(), greaterThanOrEqualTo(aggregationEstimatedSize)); |
| 232 | + assertThat(e.getBytesWanted(), equalTo((long) (aggregationEstimatedSize * 0.5))); |
| 233 | + } |
| 234 | + } |
| 235 | + } |
| 236 | + |
| 237 | + /** |
| 238 | + * DelayableWriteable that delegates expansion to a supplier. |
| 239 | + */ |
| 240 | + private static class DelegatingDelayableWriteable<T extends Writeable> extends DelayableWriteable<T> { |
| 241 | + private final Supplier<T> supplier; |
| 242 | + |
| 243 | + private DelegatingDelayableWriteable(Supplier<T> supplier) { |
| 244 | + this.supplier = supplier; |
| 245 | + } |
| 246 | + |
| 247 | + @Override |
| 248 | + public void writeTo(StreamOutput out) { |
| 249 | + throw new UnsupportedOperationException("Not to be called"); |
| 250 | + } |
| 251 | + |
| 252 | + @Override |
| 253 | + public T expand() { |
| 254 | + return supplier.get(); |
| 255 | + } |
| 256 | + |
| 257 | + @Override |
| 258 | + public Serialized<T> asSerialized(Reader<T> reader, NamedWriteableRegistry registry) { |
| 259 | + throw new UnsupportedOperationException("Not to be called"); |
| 260 | + } |
| 261 | + |
| 262 | + @Override |
| 263 | + public boolean isSerialized() { |
| 264 | + return true; |
| 265 | + } |
| 266 | + |
| 267 | + @Override |
| 268 | + public long getSerializedSize() { |
| 269 | + return 0; |
| 270 | + } |
| 271 | + |
| 272 | + @Override |
| 273 | + public void close() { |
| 274 | + // noop |
| 275 | + } |
| 276 | + } |
| 277 | + |
151 | 278 | private static class ThrowingSearchProgressListener extends SearchProgressListener { |
152 | 279 | private final AtomicInteger onQueryResult = new AtomicInteger(0); |
153 | 280 | private final AtomicInteger onPartialReduce = new AtomicInteger(0); |
|
0 commit comments