Skip to content

Commit 32588b5

Browse files
committed
Randomize groupIds block types to check most AddInput cases
1 parent d95a81b commit 32588b5

File tree

3 files changed

+248
-1
lines changed

3 files changed

+248
-1
lines changed

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
import org.elasticsearch.common.util.BitArray;
1313
import org.elasticsearch.compute.ConstantBooleanExpressionEvaluator;
1414
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
15+
import org.elasticsearch.compute.aggregation.blockhash.WrappedBlockHash;
1516
import org.elasticsearch.compute.data.Block;
1617
import org.elasticsearch.compute.data.BlockFactory;
18+
import org.elasticsearch.compute.data.BlockTypeRandomizer;
1719
import org.elasticsearch.compute.data.BooleanBlock;
1820
import org.elasticsearch.compute.data.BytesRefBlock;
1921
import org.elasticsearch.compute.data.DoubleBlock;
@@ -39,6 +41,7 @@
3941
import org.elasticsearch.compute.test.TestBlockFactory;
4042
import org.elasticsearch.core.Nullable;
4143
import org.elasticsearch.core.Releasables;
44+
import org.elasticsearch.index.analysis.AnalysisRegistry;
4245
import org.elasticsearch.xpack.esql.core.type.DataType;
4346
import org.hamcrest.Matcher;
4447

@@ -47,6 +50,7 @@
4750
import java.util.SortedSet;
4851
import java.util.TreeSet;
4952
import java.util.function.Function;
53+
import java.util.function.Supplier;
5054
import java.util.stream.DoubleStream;
5155
import java.util.stream.IntStream;
5256
import java.util.stream.LongStream;
@@ -104,7 +108,7 @@ private Operator.OperatorFactory simpleWithMode(
104108
if (randomBoolean()) {
105109
supplier = chunkGroups(emitChunkSize, supplier);
106110
}
107-
return new HashAggregationOperator.HashAggregationOperatorFactory(
111+
return new RandomizingHashAggregationOperatorFactory(
108112
List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),
109113
mode,
110114
List.of(supplier.groupingAggregatorFactory(mode, channels(mode))),
@@ -777,4 +781,76 @@ public String describe() {
777781
};
778782
}
779783

784+
/**
 * Test-only {@link Operator.OperatorFactory} that mirrors
 * {@link HashAggregationOperator.HashAggregationOperatorFactory} but wraps the built
 * {@link BlockHash} so the group-ids {@link IntBlock} handed to
 * {@link GroupingAggregatorFunction.AddInput} gets a randomized concrete block type,
 * exercising every {@code AddInput} overload.
 */
private record RandomizingHashAggregationOperatorFactory(
    List<BlockHash.GroupSpec> groups,
    AggregatorMode aggregatorMode,
    List<GroupingAggregator.Factory> aggregators,
    int maxPageSize,
    AnalysisRegistry analysisRegistry
) implements Operator.OperatorFactory {

    @Override
    public Operator get(DriverContext driverContext) {
        Supplier<BlockHash> blockHashSupplier = () -> {
            // Build the same BlockHash the production factory would build: a categorize
            // hash when any group spec is a categorization, otherwise the standard one.
            BlockHash blockHash = groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)
                ? BlockHash.buildCategorizeBlockHash(
                    groups,
                    aggregatorMode,
                    driverContext.blockFactory(),
                    analysisRegistry,
                    maxPageSize
                )
                : BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false);

            // Wrap it so every group-ids block flowing to the aggregator is re-encoded
            // into a randomly chosen block implementation before being forwarded.
            return new WrappedBlockHash(driverContext.blockFactory(), blockHash) {
                @Override
                public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
                    blockHash.add(page, new GroupingAggregatorFunction.AddInput() {
                        @Override
                        public void add(int positionOffset, IntBlock groupIds) {
                            // NOTE(review): group ids are left untouched for partial input —
                            // presumably intermediate pages must keep their original block
                            // type; confirm against the aggregator's intermediate contract.
                            IntBlock newGroupIds = aggregatorMode.isInputPartial()
                                ? groupIds
                                : BlockTypeRandomizer.randomizeBlockType(groupIds);
                            addInput.add(positionOffset, newGroupIds);
                        }

                        @Override
                        public void add(int positionOffset, IntArrayBlock groupIds) {
                            // Funnel the typed overload through the generic IntBlock one
                            // so its group ids are randomized too.
                            add(positionOffset, (IntBlock) groupIds);
                        }

                        @Override
                        public void add(int positionOffset, IntBigArrayBlock groupIds) {
                            add(positionOffset, (IntBlock) groupIds);
                        }

                        @Override
                        public void add(int positionOffset, IntVector groupIds) {
                            // Vectors are lifted to blocks and randomized via the generic overload.
                            add(positionOffset, groupIds.asBlock());
                        }

                        @Override
                        public void close() {
                            addInput.close();
                        }
                    });
                }
            };
        };

        return new HashAggregationOperator(aggregators, blockHashSupplier, driverContext);
    }

    @Override
    public String describe() {
        // Delegate to the production factory so the description matches what it reports.
        return new HashAggregationOperator.HashAggregationOperatorFactory(
            groups,
            aggregatorMode,
            aggregators,
            maxPageSize,
            analysisRegistry
        ).describe();
    }
}
855+
780856
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.compute.aggregation.blockhash;
9+
10+
import org.elasticsearch.common.unit.ByteSizeValue;
11+
import org.elasticsearch.common.util.BigArrays;
12+
import org.elasticsearch.common.util.BitArray;
13+
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
14+
import org.elasticsearch.compute.data.Block;
15+
import org.elasticsearch.compute.data.BlockFactory;
16+
import org.elasticsearch.compute.data.IntBlock;
17+
import org.elasticsearch.compute.data.IntVector;
18+
import org.elasticsearch.compute.data.Page;
19+
import org.elasticsearch.core.ReleasableIterator;
20+
21+
/**
22+
* A test BlockHash that wraps another one.
23+
* <p>
24+
* Its methods can be overridden to implement custom behaviours or checks.
25+
* </p>
26+
*/
27+
public class WrappedBlockHash extends BlockHash {
28+
protected BlockHash blockHash;
29+
30+
public WrappedBlockHash(BlockFactory blockFactory, BlockHash blockHash) {
31+
super(blockFactory);
32+
this.blockHash = blockHash;
33+
}
34+
35+
@Override
36+
public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
37+
blockHash.add(page, addInput);
38+
}
39+
40+
@Override
41+
public ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
42+
return blockHash.lookup(page, targetBlockSize);
43+
}
44+
45+
@Override
46+
public Block[] getKeys() {
47+
return blockHash.getKeys();
48+
}
49+
50+
@Override
51+
public IntVector nonEmpty() {
52+
return blockHash.nonEmpty();
53+
}
54+
55+
@Override
56+
public BitArray seenGroupIds(BigArrays bigArrays) {
57+
return blockHash.seenGroupIds(bigArrays);
58+
}
59+
60+
@Override
61+
public void close() {
62+
blockHash.close();
63+
}
64+
65+
@Override
66+
public String toString() {
67+
return blockHash.toString();
68+
}
69+
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.compute.data;
9+
10+
import org.elasticsearch.compute.test.TestBlockFactory;
11+
12+
import java.util.BitSet;
13+
14+
import static org.elasticsearch.test.ESTestCase.randomIntBetween;
15+
16+
public class BlockTypeRandomizer {
17+
private BlockTypeRandomizer() {}
18+
19+
/**
20+
* Returns a block with the same contents, but with a randomized type (Constant, vector, big-array...).
21+
* <p>
22+
* The new block uses a non-breaking block builder, and doesn't increment the circuit breaking.
23+
* This is done to avoid randomly using more memory in tests that expect a deterministic memory usage.
24+
* </p>
25+
*/
26+
public static IntBlock randomizeBlockType(IntBlock block) {
27+
// Just to track the randomization
28+
int classCount = 4;
29+
30+
BlockFactory blockFactory = TestBlockFactory.getNonBreakingInstance();
31+
32+
//
33+
// ConstantNullBlock. It requires all positions to be null
34+
//
35+
if (randomIntBetween(0, --classCount) == 0 && block.areAllValuesNull()) {
36+
if (block instanceof ConstantNullBlock) {
37+
return block;
38+
}
39+
return new ConstantNullBlock(block.getPositionCount(), blockFactory);
40+
}
41+
42+
//
43+
// IntVectorBlock. It doesn't allow nulls or multivalues
44+
//
45+
if (randomIntBetween(0, --classCount) == 0 && block.doesHaveMultivaluedFields() == false && block.mayHaveNulls() == false) {
46+
if (block instanceof IntVectorBlock) {
47+
return block;
48+
}
49+
50+
int[] values = new int[block.getPositionCount()];
51+
for (int i = 0; i < values.length; i++) {
52+
values[i] = block.getInt(i);
53+
}
54+
55+
return new IntVectorBlock(new IntArrayVector(values, block.getPositionCount(), blockFactory));
56+
}
57+
58+
// Both IntArrayBlock and IntBigArrayBlock need a nullsBitSet and a firstValueIndexes int[]
59+
int[] firstValueIndexes = new int[block.getPositionCount() + 1];
60+
BitSet nullsMask = new BitSet(block.getPositionCount());
61+
for (int i = 0; i < block.getPositionCount(); i++) {
62+
firstValueIndexes[i] = block.getFirstValueIndex(i);
63+
64+
if (block.isNull(i)) {
65+
nullsMask.set(i);
66+
}
67+
}
68+
int totalValues = block.getFirstValueIndex(block.getPositionCount() - 1) + block.getValueCount(block.getPositionCount() - 1);
69+
firstValueIndexes[firstValueIndexes.length - 1] = totalValues;
70+
71+
//
72+
// IntArrayBlock
73+
//
74+
if (randomIntBetween(0, --classCount) == 0) {
75+
if (block instanceof IntVectorBlock) {
76+
return block;
77+
}
78+
79+
int[] values = new int[totalValues];
80+
for (int i = 0; i < values.length; i++) {
81+
values[i] = block.getInt(i);
82+
}
83+
84+
return new IntArrayBlock(values, block.getPositionCount(), firstValueIndexes, nullsMask, block.mvOrdering(), blockFactory);
85+
}
86+
assert classCount == 1;
87+
88+
//
89+
// IntBigArrayBlock
90+
//
91+
if (block instanceof IntBigArrayBlock) {
92+
return block;
93+
}
94+
95+
var intArray = blockFactory.bigArrays().newIntArray(totalValues);
96+
for (int i = 0; i < block.getPositionCount(); i++) {
97+
intArray.set(i, block.getInt(i));
98+
}
99+
100+
return new IntBigArrayBlock(intArray, block.getPositionCount(), firstValueIndexes, nullsMask, block.mvOrdering(), blockFactory);
101+
}
102+
}

0 commit comments

Comments
 (0)