Skip to content

Commit 8dd15ac

Browse files
committed
randomize CategorizePackedValuesBlockHashTests
1 parent a9c3ed7 commit 8dd15ac

File tree

2 files changed

+71
-40
lines changed

2 files changed

+71
-40
lines changed

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,6 @@ public void close() {
130130
} finally {
131131
page.releaseBlocks();
132132
}
133-
134-
// TODO: randomize values? May give wrong results
135-
// TODO: assert the categorizer state after adding pages.
136133
}
137134

138135
public void testCategorizeRawMultivalue() {

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java

Lines changed: 71 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,10 @@
2020
import org.elasticsearch.compute.aggregation.ValuesBytesRefAggregatorFunctionSupplier;
2121
import org.elasticsearch.compute.data.Block;
2222
import org.elasticsearch.compute.data.BlockFactory;
23+
import org.elasticsearch.compute.data.BlockUtils;
2324
import org.elasticsearch.compute.data.BytesRefBlock;
24-
import org.elasticsearch.compute.data.BytesRefVector;
2525
import org.elasticsearch.compute.data.ElementType;
2626
import org.elasticsearch.compute.data.IntBlock;
27-
import org.elasticsearch.compute.data.IntVector;
2827
import org.elasticsearch.compute.data.Page;
2928
import org.elasticsearch.compute.operator.CannedSourceOperator;
3029
import org.elasticsearch.compute.operator.Driver;
@@ -72,6 +71,8 @@ public void testCategorize_withDriver() {
7271
BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking();
7372
CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST);
7473
DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
74+
boolean withNull = randomBoolean();
75+
boolean withMultivalues = randomBoolean();
7576

7677
List<BlockHash.GroupSpec> groupSpecs = List.of(
7778
new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true),
@@ -80,28 +81,42 @@ public void testCategorize_withDriver() {
8081

8182
LocalSourceOperator.BlockSupplier input1 = () -> {
8283
try (
83-
BytesRefVector.Builder messagesBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
84-
IntVector.Builder idsBuilder = driverContext.blockFactory().newIntVectorBuilder(10)
84+
BytesRefBlock.Builder messagesBuilder = driverContext.blockFactory().newBytesRefBlockBuilder(10);
85+
IntBlock.Builder idsBuilder = driverContext.blockFactory().newIntBlockBuilder(10)
8586
) {
87+
if (withMultivalues) {
88+
messagesBuilder.beginPositionEntry();
89+
}
8690
messagesBuilder.appendBytesRef(new BytesRef("connected to 1.1.1"));
8791
messagesBuilder.appendBytesRef(new BytesRef("connected to 1.1.2"));
92+
if (withMultivalues) {
93+
messagesBuilder.endPositionEntry();
94+
}
95+
idsBuilder.appendInt(7);
96+
if (withMultivalues == false) {
97+
idsBuilder.appendInt(7);
98+
}
99+
88100
messagesBuilder.appendBytesRef(new BytesRef("connected to 1.1.3"));
89101
messagesBuilder.appendBytesRef(new BytesRef("connection error"));
90102
messagesBuilder.appendBytesRef(new BytesRef("connection error"));
91103
messagesBuilder.appendBytesRef(new BytesRef("connected to 1.1.4"));
92-
idsBuilder.appendInt(7);
93-
idsBuilder.appendInt(7);
94104
idsBuilder.appendInt(42);
95105
idsBuilder.appendInt(7);
96106
idsBuilder.appendInt(42);
97107
idsBuilder.appendInt(7);
98-
return new Block[] { messagesBuilder.build().asBlock(), idsBuilder.build().asBlock() };
108+
109+
if (withNull) {
110+
messagesBuilder.appendNull();
111+
idsBuilder.appendInt(43);
112+
}
113+
return new Block[] { messagesBuilder.build(), idsBuilder.build() };
99114
}
100115
};
101116
LocalSourceOperator.BlockSupplier input2 = () -> {
102117
try (
103-
BytesRefVector.Builder messagesBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
104-
IntVector.Builder idsBuilder = driverContext.blockFactory().newIntVectorBuilder(10)
118+
BytesRefBlock.Builder messagesBuilder = driverContext.blockFactory().newBytesRefBlockBuilder(10);
119+
IntBlock.Builder idsBuilder = driverContext.blockFactory().newIntBlockBuilder(10)
105120
) {
106121
messagesBuilder.appendBytesRef(new BytesRef("connected to 2.1.1"));
107122
messagesBuilder.appendBytesRef(new BytesRef("connected to 2.1.2"));
@@ -111,7 +126,11 @@ public void testCategorize_withDriver() {
111126
idsBuilder.appendInt(7);
112127
idsBuilder.appendInt(7);
113128
idsBuilder.appendInt(42);
114-
return new Block[] { messagesBuilder.build().asBlock(), idsBuilder.build().asBlock() };
129+
if (withNull) {
130+
messagesBuilder.appendNull();
131+
idsBuilder.appendNull();
132+
}
133+
return new Block[] { messagesBuilder.build(), idsBuilder.build() };
115134
}
116135
};
117136

@@ -177,38 +196,53 @@ public void testCategorize_withDriver() {
177196
BytesRefBlock outputValues = finalOutput.get(0).getBlock(2);
178197
assertThat(outputIds.getPositionCount(), equalTo(outputMessages.getPositionCount()));
179198
assertThat(outputValues.getPositionCount(), equalTo(outputMessages.getPositionCount()));
180-
Map<String, Map<Integer, Set<String>>> values = new HashMap<>();
199+
Map<String, Map<Integer, Set<String>>> result = new HashMap<>();
181200
for (int i = 0; i < outputMessages.getPositionCount(); i++) {
182-
String message = outputMessages.getBytesRef(i, new BytesRef()).utf8ToString();
183-
int id = outputIds.getInt(i);
184-
int valuesFromIndex = outputValues.getFirstValueIndex(i);
185-
int valuesToIndex = valuesFromIndex + outputValues.getValueCount(i);
186-
for (int valueIndex = valuesFromIndex; valueIndex < valuesToIndex; valueIndex++) {
187-
String value = outputValues.getBytesRef(valueIndex, new BytesRef()).utf8ToString();
188-
values.computeIfAbsent(message, key -> new HashMap<>()).computeIfAbsent(id, key -> new HashSet<>()).add(value);
201+
BytesRef messageBytesRef = ((BytesRef) BlockUtils.toJavaObject(outputMessages, i));
202+
String message = messageBytesRef == null ? null : messageBytesRef.utf8ToString();
203+
result.computeIfAbsent(message, key -> new HashMap<>());
204+
205+
Integer id = (Integer) BlockUtils.toJavaObject(outputIds, i);
206+
result.get(message).computeIfAbsent(id, key -> new HashSet<>());
207+
208+
Object values = BlockUtils.toJavaObject(outputValues, i);
209+
if (values == null) {
210+
result.get(message).get(id).add(null);
211+
} else {
212+
if ((values instanceof List) == false) {
213+
values = List.of(values);
214+
}
215+
for (Object valueObject : (List<?>) values) {
216+
BytesRef value = (BytesRef) valueObject;
217+
result.get(message).get(id).add(value.utf8ToString());
218+
}
189219
}
190220
}
191221
Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks));
192222

193-
assertThat(
194-
values,
195-
equalTo(
196-
Map.of(
197-
".*?connected.+?to.*?",
198-
Map.of(
199-
7,
200-
Set.of("connected to 1.1.1", "connected to 1.1.2", "connected to 1.1.4", "connected to 2.1.2"),
201-
42,
202-
Set.of("connected to 1.1.3"),
203-
111,
204-
Set.of("connected to 2.1.1")
205-
),
206-
".*?connection.+?error.*?",
207-
Map.of(7, Set.of("connection error"), 42, Set.of("connection error")),
208-
".*?disconnected.*?",
209-
Map.of(7, Set.of("disconnected"))
210-
)
211-
)
223+
Map<String, Map<Integer, Set<String>>> expectedResult = Map.of(
224+
".*?connected.+?to.*?",
225+
Map.of(
226+
7,
227+
Set.of("connected to 1.1.1", "connected to 1.1.2", "connected to 1.1.4", "connected to 2.1.2"),
228+
42,
229+
Set.of("connected to 1.1.3"),
230+
111,
231+
Set.of("connected to 2.1.1")
232+
),
233+
".*?connection.+?error.*?",
234+
Map.of(7, Set.of("connection error"), 42, Set.of("connection error")),
235+
".*?disconnected.*?",
236+
Map.of(7, Set.of("disconnected"))
212237
);
238+
if (withNull) {
239+
expectedResult = new HashMap<>(expectedResult);
240+
expectedResult.put(null, new HashMap<>());
241+
expectedResult.get(null).put(null, new HashSet<>());
242+
expectedResult.get(null).get(null).add(null);
243+
expectedResult.get(null).put(43, new HashSet<>());
244+
expectedResult.get(null).get(43).add(null);
245+
}
246+
assertThat(result, equalTo(expectedResult));
213247
}
214248
}

0 commit comments

Comments
 (0)