Skip to content

Commit 4848184

Browse files
committed
Add "emitEmptyBuckets" parameter to the "Bucket" function.
1 parent a692cbd commit 4848184

File tree

20 files changed

+667
-56
lines changed

20 files changed

+667
-56
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,12 @@ private static Operator operator(DriverContext driverContext, String grouping, S
191191
new BlockHash.GroupSpec(2, ElementType.BYTES_REF)
192192
);
193193
case TOP_N_LONGS -> List.of(
194-
new BlockHash.GroupSpec(0, ElementType.LONG, null, new BlockHash.TopNDef(0, true, true, TOP_N_LIMIT))
194+
new BlockHash.GroupSpec(0, ElementType.LONG, null, new BlockHash.TopNDef(0, true, true, TOP_N_LIMIT), null)
195195
);
196196
default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]");
197197
};
198198
return new HashAggregationOperator(
199+
groups,
199200
List.of(supplier(op, dataType, filter).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(groups.size()))),
200201
() -> BlockHash.build(groups, driverContext.blockFactory(), 16 * 1024, false),
201202
driverContext

benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ private static Operator operator(DriverContext driverContext, int groups, String
122122
}
123123
List<BlockHash.GroupSpec> groupSpec = List.of(new BlockHash.GroupSpec(0, ElementType.LONG));
124124
return new HashAggregationOperator(
125+
groupSpec,
125126
List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))),
126127
() -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false),
127128
driverContext

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ static TransportVersion def(int id) {
345345
public static final TransportVersion PROJECT_STATE_REGISTRY_ENTRY = def(9_124_0_00);
346346
public static final TransportVersion ML_INFERENCE_LLAMA_ADDED = def(9_125_0_00);
347347
public static final TransportVersion SHARD_WRITE_LOAD_IN_CLUSTER_INFO = def(9_126_0_00);
348+
public static final TransportVersion ESQL_EMIT_EMPTY_BUCKETS = def(9_127_0_00);
348349

349350
/*
350351
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ public class GroupingAggregator implements Releasable {
2323

2424
private final AggregatorMode mode;
2525

26+
public AggregatorMode getMode() {
27+
return mode;
28+
}
29+
2630
public interface Factory extends Function<DriverContext, GroupingAggregator>, Describable {}
2731

2832
public GroupingAggregator(GroupingAggregatorFunction aggregatorFunction, AggregatorMode mode) {

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,12 @@ public abstract class BlockHash implements Releasable, SeenGroupIds {
127127
*/
128128
public record TopNDef(int order, boolean asc, boolean nullsFirst, int limit) {}
129129

130+
public interface EmptyBucketGenerator {
131+
int getEmptyBucketCount();
132+
133+
void generate(Block.Builder blockBuilder);
134+
}
135+
130136
/**
131137
* Configuration for a BlockHash group spec that is doing text categorization.
132138
*/
@@ -137,13 +143,19 @@ public enum OutputFormat {
137143
}
138144
}
139145

140-
public record GroupSpec(int channel, ElementType elementType, @Nullable CategorizeDef categorizeDef, @Nullable TopNDef topNDef) {
146+
public record GroupSpec(
147+
int channel,
148+
ElementType elementType,
149+
@Nullable CategorizeDef categorizeDef,
150+
@Nullable TopNDef topNDef,
151+
@Nullable EmptyBucketGenerator emptyBucketGenerator
152+
) {
141153
public GroupSpec(int channel, ElementType elementType) {
142-
this(channel, elementType, null, null);
154+
this(channel, elementType, null, null, null);
143155
}
144156

145157
public GroupSpec(int channel, ElementType elementType, CategorizeDef categorizeDef) {
146-
this(channel, elementType, categorizeDef, null);
158+
this(channel, elementType, categorizeDef, null, null);
147159
}
148160

149161
public boolean isCategorize() {

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
2121
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
2222
import org.elasticsearch.compute.data.Block;
23+
import org.elasticsearch.compute.data.DocBlock;
2324
import org.elasticsearch.compute.data.IntArrayBlock;
2425
import org.elasticsearch.compute.data.IntBigArrayBlock;
2526
import org.elasticsearch.compute.data.IntVector;
@@ -34,6 +35,7 @@
3435
import java.util.Arrays;
3536
import java.util.List;
3637
import java.util.Objects;
38+
import java.util.concurrent.atomic.AtomicBoolean;
3739
import java.util.function.Supplier;
3840

3941
import static java.util.Objects.requireNonNull;
@@ -52,6 +54,7 @@ public record HashAggregationOperatorFactory(
5254
public Operator get(DriverContext driverContext) {
5355
if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) {
5456
return new HashAggregationOperator(
57+
groups,
5558
aggregators,
5659
() -> BlockHash.buildCategorizeBlockHash(
5760
groups,
@@ -64,6 +67,7 @@ public Operator get(DriverContext driverContext) {
6467
);
6568
}
6669
return new HashAggregationOperator(
70+
groups,
6771
aggregators,
6872
() -> BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false),
6973
driverContext
@@ -83,6 +87,7 @@ public String describe() {
8387
private boolean finished;
8488
private Page output;
8589

90+
private final List<BlockHash.GroupSpec> groups;
8691
private final BlockHash blockHash;
8792

8893
protected final List<GroupingAggregator> aggregators;
@@ -117,10 +122,12 @@ public String describe() {
117122

118123
@SuppressWarnings("this-escape")
119124
public HashAggregationOperator(
125+
List<BlockHash.GroupSpec> groups,
120126
List<GroupingAggregator.Factory> aggregators,
121127
Supplier<BlockHash> blockHash,
122128
DriverContext driverContext
123129
) {
130+
this.groups = groups;
124131
this.aggregators = new ArrayList<>(aggregators.size());
125132
this.driverContext = driverContext;
126133
boolean success = false;
@@ -142,8 +149,22 @@ public boolean needsInput() {
142149
return finished == false;
143150
}
144151

152+
private final AtomicBoolean isInitialPage = new AtomicBoolean(true);
153+
145154
@Override
146155
public void addInput(Page page) {
156+
if (isInitialPage.compareAndSet(true, false)
157+
&& (aggregators.size() == 0 || AggregatorMode.INITIAL.equals(aggregators.get(0).getMode()))) {
158+
Page initialPage = createInitialPage(page);
159+
if (initialPage != null) {
160+
addInputInternal(initialPage);
161+
return;
162+
}
163+
}
164+
addInputInternal(page);
165+
}
166+
167+
private void addInputInternal(Page page) {
147168
try {
148169
GroupingAggregatorFunction.AddInput[] prepared = new GroupingAggregatorFunction.AddInput[aggregators.size()];
149170
class AddInput implements GroupingAggregatorFunction.AddInput {
@@ -289,6 +310,42 @@ protected Page wrapPage(Page page) {
289310
return page;
290311
}
291312

313+
private Page createInitialPage(Page page) {
314+
// If no groups are generating bucket keys, move on
315+
if (groups.stream().allMatch(g -> g.emptyBucketGenerator() == null)) {
316+
return page;
317+
}
318+
Block.Builder[] blockBuilders = new Block.Builder[page.getBlockCount()];
319+
for (int channel = 0; channel < page.getBlockCount(); channel++) {
320+
Block block = page.getBlock(channel);
321+
blockBuilders[channel] = block.elementType().newBlockBuilder(block.getPositionCount(), driverContext.blockFactory());
322+
blockBuilders[channel].copyFrom(block, 0, block.getPositionCount());
323+
}
324+
for (BlockHash.GroupSpec group : groups) {
325+
BlockHash.EmptyBucketGenerator emptyBucketGenerator = group.emptyBucketGenerator();
326+
if (emptyBucketGenerator != null) {
327+
for (int channel = 0; channel < page.getBlockCount(); channel++) {
328+
if (group.channel() == channel) {
329+
emptyBucketGenerator.generate(blockBuilders[channel]);
330+
} else {
331+
for (int i = 0; i < emptyBucketGenerator.getEmptyBucketCount(); i++) {
332+
if (page.getBlock(channel) instanceof DocBlock) {
333+
// TODO: DocBlock doesn't allow appending nulls
334+
((DocBlock.Builder) blockBuilders[channel]).appendShard(0).appendSegment(0).appendDoc(0);
335+
} else {
336+
blockBuilders[channel].appendNull();
337+
}
338+
}
339+
}
340+
}
341+
}
342+
}
343+
Block[] blocks = Arrays.stream(blockBuilders).map(Block.Builder::build).toArray(Block[]::new);
344+
Releasables.closeExpectNoException(blockBuilders);
345+
page.releaseBlocks();
346+
return new Page(blocks);
347+
}
348+
292349
@Override
293350
public String toString() {
294351
StringBuilder sb = new StringBuilder();

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperator.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public record Factory(
4040
@Override
4141
public Operator get(DriverContext driverContext) {
4242
// TODO: use TimeSeriesBlockHash when possible
43-
return new TimeSeriesAggregationOperator(timeBucket, aggregators, () -> {
43+
return new TimeSeriesAggregationOperator(timeBucket, groups, aggregators, () -> {
4444
if (sortedInput && groups.size() == 2) {
4545
return new TimeSeriesBlockHash(groups.get(0).channel(), groups.get(1).channel(), driverContext.blockFactory());
4646
} else {
@@ -68,11 +68,12 @@ public String describe() {
6868

6969
public TimeSeriesAggregationOperator(
7070
Rounding.Prepared timeBucket,
71+
List<BlockHash.GroupSpec> groups,
7172
List<GroupingAggregator.Factory> aggregators,
7273
Supplier<BlockHash> blockHash,
7374
DriverContext driverContext
7475
) {
75-
super(aggregators, blockHash, driverContext);
76+
super(groups, aggregators, blockHash, driverContext);
7677
this.timeBucket = timeBucket;
7778
}
7879

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,7 @@ public void close() {
910910
};
911911
};
912912

913-
return new HashAggregationOperator(aggregators, blockHashSupplier, driverContext);
913+
return new HashAggregationOperator(groups, aggregators, blockHashSupplier, driverContext);
914914
}
915915

916916
@Override

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/TopNBlockHashTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ private void hashBatchesCallbackOnLast(Consumer<OrdsAndKeys> callback, Block[]..
363363
private BlockHash buildBlockHash(int emitBatchSize, Block... values) {
364364
List<BlockHash.GroupSpec> specs = new ArrayList<>(values.length);
365365
for (int c = 0; c < values.length; c++) {
366-
specs.add(new BlockHash.GroupSpec(c, values[c].elementType(), null, topNDef(c)));
366+
specs.add(new BlockHash.GroupSpec(c, values[c].elementType(), null, topNDef(c), null));
367367
}
368368
assert forcePackedHash == false : "Packed TopN hash not implemented yet";
369369
/*return forcePackedHash

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ public void testTopNNullsLast() {
114114

115115
try (
116116
var operator = new HashAggregationOperator.HashAggregationOperatorFactory(
117-
List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, false, 3))),
117+
List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, false, 3), null)),
118118
mode,
119119
List.of(
120120
new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels),
@@ -191,7 +191,7 @@ public void testTopNNullsFirst() {
191191

192192
try (
193193
var operator = new HashAggregationOperator.HashAggregationOperatorFactory(
194-
List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, true, 3))),
194+
List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, true, 3), null)),
195195
mode,
196196
List.of(
197197
new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels),
@@ -277,7 +277,7 @@ public void testTopNNullsIntermediateDiscards() {
277277
var maxAggregatorChannels = mode.isInputPartial() ? List.of(3, 4) : List.of(1);
278278

279279
return new HashAggregationOperator.HashAggregationOperatorFactory(
280-
List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, false, 3))),
280+
List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, false, 3), null)),
281281
mode,
282282
List.of(
283283
new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, sumAggregatorChannels),

0 commit comments

Comments
 (0)