Skip to content

Commit e6ac068

Browse files
committed
CategorizePackedValuesBlockHashTests
1 parent d7af8bf commit e6ac068

File tree

1 file changed

+214
-0
lines changed

1 file changed

+214
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.compute.aggregation.blockhash;
9+
10+
import org.apache.lucene.util.BytesRef;
11+
import org.elasticsearch.analysis.common.CommonAnalysisPlugin;
12+
import org.elasticsearch.common.breaker.CircuitBreaker;
13+
import org.elasticsearch.common.collect.Iterators;
14+
import org.elasticsearch.common.settings.Settings;
15+
import org.elasticsearch.common.unit.ByteSizeValue;
16+
import org.elasticsearch.common.util.BigArrays;
17+
import org.elasticsearch.common.util.MockBigArrays;
18+
import org.elasticsearch.common.util.PageCacheRecycler;
19+
import org.elasticsearch.compute.aggregation.AggregatorMode;
20+
import org.elasticsearch.compute.aggregation.ValuesBytesRefAggregatorFunctionSupplier;
21+
import org.elasticsearch.compute.data.Block;
22+
import org.elasticsearch.compute.data.BlockFactory;
23+
import org.elasticsearch.compute.data.BytesRefBlock;
24+
import org.elasticsearch.compute.data.BytesRefVector;
25+
import org.elasticsearch.compute.data.ElementType;
26+
import org.elasticsearch.compute.data.IntBlock;
27+
import org.elasticsearch.compute.data.IntVector;
28+
import org.elasticsearch.compute.data.Page;
29+
import org.elasticsearch.compute.operator.CannedSourceOperator;
30+
import org.elasticsearch.compute.operator.Driver;
31+
import org.elasticsearch.compute.operator.DriverContext;
32+
import org.elasticsearch.compute.operator.HashAggregationOperator;
33+
import org.elasticsearch.compute.operator.LocalSourceOperator;
34+
import org.elasticsearch.compute.operator.PageConsumerOperator;
35+
import org.elasticsearch.core.Releasables;
36+
import org.elasticsearch.env.Environment;
37+
import org.elasticsearch.env.TestEnvironment;
38+
import org.elasticsearch.index.analysis.AnalysisRegistry;
39+
import org.elasticsearch.indices.analysis.AnalysisModule;
40+
import org.elasticsearch.plugins.scanners.StablePluginsRegistry;
41+
import org.elasticsearch.xpack.ml.MachineLearning;
42+
import org.junit.Before;
43+
44+
import java.io.IOException;
45+
import java.util.ArrayList;
46+
import java.util.HashMap;
47+
import java.util.HashSet;
48+
import java.util.List;
49+
import java.util.Map;
50+
import java.util.Set;
51+
52+
import static org.elasticsearch.compute.operator.OperatorTestCase.runDriver;
53+
import static org.hamcrest.Matchers.equalTo;
54+
import static org.hamcrest.Matchers.hasSize;
55+
56+
public class CategorizePackedValuesBlockHashTests extends BlockHashTestCase {
57+
58+
private AnalysisRegistry analysisRegistry;
59+
60+
@Before
61+
private void initAnalysisRegistry() throws IOException {
62+
analysisRegistry = new AnalysisModule(
63+
TestEnvironment.newEnvironment(
64+
Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build()
65+
),
66+
List.of(new MachineLearning(Settings.EMPTY), new CommonAnalysisPlugin()),
67+
new StablePluginsRegistry()
68+
).getAnalysisRegistry();
69+
}
70+
71+
public void testCategorize_withDriver() {
72+
BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking();
73+
CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST);
74+
DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
75+
76+
List<BlockHash.GroupSpec> groupSpecs = List.of(
77+
new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true),
78+
new BlockHash.GroupSpec(1, ElementType.INT, false)
79+
);
80+
81+
LocalSourceOperator.BlockSupplier input1 = () -> {
82+
try (
83+
BytesRefVector.Builder messagesBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
84+
IntVector.Builder idsBuilder = driverContext.blockFactory().newIntVectorBuilder(10)
85+
) {
86+
messagesBuilder.appendBytesRef(new BytesRef("connected to 1.1.1"));
87+
messagesBuilder.appendBytesRef(new BytesRef("connected to 1.1.2"));
88+
messagesBuilder.appendBytesRef(new BytesRef("connected to 1.1.3"));
89+
messagesBuilder.appendBytesRef(new BytesRef("connection error"));
90+
messagesBuilder.appendBytesRef(new BytesRef("connection error"));
91+
messagesBuilder.appendBytesRef(new BytesRef("connected to 1.1.4"));
92+
idsBuilder.appendInt(7);
93+
idsBuilder.appendInt(7);
94+
idsBuilder.appendInt(42);
95+
idsBuilder.appendInt(7);
96+
idsBuilder.appendInt(42);
97+
idsBuilder.appendInt(7);
98+
return new Block[] { messagesBuilder.build().asBlock(), idsBuilder.build().asBlock() };
99+
}
100+
};
101+
LocalSourceOperator.BlockSupplier input2 = () -> {
102+
try (
103+
BytesRefVector.Builder messagesBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
104+
IntVector.Builder idsBuilder = driverContext.blockFactory().newIntVectorBuilder(10)
105+
) {
106+
messagesBuilder.appendBytesRef(new BytesRef("connected to 2.1.1"));
107+
messagesBuilder.appendBytesRef(new BytesRef("connected to 2.1.2"));
108+
messagesBuilder.appendBytesRef(new BytesRef("disconnected"));
109+
messagesBuilder.appendBytesRef(new BytesRef("connection error"));
110+
idsBuilder.appendInt(111);
111+
idsBuilder.appendInt(7);
112+
idsBuilder.appendInt(7);
113+
idsBuilder.appendInt(42);
114+
return new Block[] { messagesBuilder.build().asBlock(), idsBuilder.build().asBlock() };
115+
}
116+
};
117+
118+
List<Page> intermediateOutput = new ArrayList<>();
119+
120+
Driver driver = new Driver(
121+
driverContext,
122+
new LocalSourceOperator(input1),
123+
List.of(
124+
new HashAggregationOperator.HashAggregationOperatorFactory(
125+
groupSpecs,
126+
AggregatorMode.INITIAL,
127+
List.of(new ValuesBytesRefAggregatorFunctionSupplier(List.of(0)).groupingAggregatorFactory(AggregatorMode.INITIAL)),
128+
16 * 1024,
129+
analysisRegistry
130+
).get(driverContext)
131+
),
132+
new PageConsumerOperator(intermediateOutput::add),
133+
() -> {}
134+
);
135+
runDriver(driver);
136+
137+
driver = new Driver(
138+
driverContext,
139+
new LocalSourceOperator(input2),
140+
List.of(
141+
new HashAggregationOperator.HashAggregationOperatorFactory(
142+
groupSpecs,
143+
AggregatorMode.INITIAL,
144+
List.of(new ValuesBytesRefAggregatorFunctionSupplier(List.of(0)).groupingAggregatorFactory(AggregatorMode.INITIAL)),
145+
16 * 1024,
146+
analysisRegistry
147+
).get(driverContext)
148+
),
149+
new PageConsumerOperator(intermediateOutput::add),
150+
() -> {}
151+
);
152+
runDriver(driver);
153+
154+
List<Page> finalOutput = new ArrayList<>();
155+
156+
driver = new Driver(
157+
driverContext,
158+
new CannedSourceOperator(intermediateOutput.iterator()),
159+
List.of(
160+
new HashAggregationOperator.HashAggregationOperatorFactory(
161+
groupSpecs,
162+
AggregatorMode.FINAL,
163+
List.of(new ValuesBytesRefAggregatorFunctionSupplier(List.of(2)).groupingAggregatorFactory(AggregatorMode.FINAL)),
164+
16 * 1024,
165+
analysisRegistry
166+
).get(driverContext)
167+
),
168+
new PageConsumerOperator(finalOutput::add),
169+
() -> {}
170+
);
171+
runDriver(driver);
172+
173+
assertThat(finalOutput, hasSize(1));
174+
assertThat(finalOutput.get(0).getBlockCount(), equalTo(3));
175+
BytesRefBlock outputMessages = finalOutput.get(0).getBlock(0);
176+
IntBlock outputIds = finalOutput.get(0).getBlock(1);
177+
BytesRefBlock outputValues = finalOutput.get(0).getBlock(2);
178+
assertThat(outputIds.getPositionCount(), equalTo(outputMessages.getPositionCount()));
179+
assertThat(outputValues.getPositionCount(), equalTo(outputMessages.getPositionCount()));
180+
Map<String, Map<Integer, Set<String>>> values = new HashMap<>();
181+
for (int i = 0; i < outputMessages.getPositionCount(); i++) {
182+
String message = outputMessages.getBytesRef(i, new BytesRef()).utf8ToString();
183+
int id = outputIds.getInt(i);
184+
int valuesFromIndex = outputValues.getFirstValueIndex(i);
185+
int valuesToIndex = valuesFromIndex + outputValues.getValueCount(i);
186+
for (int valueIndex = valuesFromIndex; valueIndex < valuesToIndex; valueIndex++) {
187+
String value = outputValues.getBytesRef(valueIndex, new BytesRef()).utf8ToString();
188+
values.computeIfAbsent(message, key -> new HashMap<>()).computeIfAbsent(id, key -> new HashSet<>()).add(value);
189+
}
190+
}
191+
Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks));
192+
193+
assertThat(
194+
values,
195+
equalTo(
196+
Map.of(
197+
".*?connected.+?to.*?",
198+
Map.of(
199+
7,
200+
Set.of("connected to 1.1.1", "connected to 1.1.2", "connected to 1.1.4", "connected to 2.1.2"),
201+
42,
202+
Set.of("connected to 1.1.3"),
203+
111,
204+
Set.of("connected to 2.1.1")
205+
),
206+
".*?connection.+?error.*?",
207+
Map.of(7, Set.of("connection error"), 42, Set.of("connection error")),
208+
".*?disconnected.*?",
209+
Map.of(7, Set.of("disconnected"))
210+
)
211+
)
212+
);
213+
}
214+
}

0 commit comments

Comments
 (0)