
Commit 36d11d3

ES|QL categorize with multiple groupings (#118173) (#118590)
* ES|QL categorize with multiple groupings.
* Fix VerifierTests
* Close stuff when constructing CategorizePackedValuesBlockHash fails
* CategorizePackedValuesBlockHashTests
* Improve categorize javadocs
* Update docs/changelog/118173.yaml
* Create CategorizePackedValuesBlockHash's delegate page differently
* Double check in BlockHash builder for single categorize
* Reuse blocks array
* More CSV tests
* Remove assumeTrue categorize_v5
* Rename test
* Two more verifier tests
* more CSV tests
* Add JavaDocs/comments
* spotless
* Refactor/unify recategorize
* Better memory accounting
* fix csv test
* randomize CategorizePackedValuesBlockHashTests
* Add TODO

Co-authored-by: Elastic Machine <[email protected]>
1 parent ba0f76a commit 36d11d3

13 files changed: +676 additions, −76 deletions

docs/changelog/118173.yaml

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+pr: 118173
+summary: ES|QL categorize with multiple groupings
+area: Machine Learning
+type: feature
+issues: []
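For context: this lets CATEGORIZE be combined with further grouping expressions in a single STATS command, e.g. FROM logs | STATS COUNT(*) BY CATEGORIZE(message), host.name (index and field names here are illustrative). Previously the block-hash builder rejected anything but a single CATEGORIZE group, as the BlockHash.java change below shows.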

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java

Lines changed: 8 additions & 5 deletions

@@ -180,13 +180,16 @@ public static BlockHash buildCategorizeBlockHash(
         List<GroupSpec> groups,
         AggregatorMode aggregatorMode,
         BlockFactory blockFactory,
-        AnalysisRegistry analysisRegistry
+        AnalysisRegistry analysisRegistry,
+        int emitBatchSize
     ) {
-        if (groups.size() != 1) {
-            throw new IllegalArgumentException("only a single CATEGORIZE group can used");
+        if (groups.size() == 1) {
+            return new CategorizeBlockHash(blockFactory, groups.get(0).channel, aggregatorMode, analysisRegistry);
+        } else {
+            assert groups.get(0).isCategorize();
+            assert groups.subList(1, groups.size()).stream().noneMatch(GroupSpec::isCategorize);
+            return new CategorizePackedValuesBlockHash(groups, blockFactory, aggregatorMode, analysisRegistry, emitBatchSize);
         }
-
-        return new CategorizeBlockHash(blockFactory, groups.get(0).channel, aggregatorMode, analysisRegistry);
     }

     /**
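To make the new dispatch concrete: the multi-grouping branch chains two hashes, categorization first, then packed-values hashing over the category IDs plus the remaining keys. A minimal self-contained sketch of that composition in plain Java, with toy stand-ins for the real BlockHash classes and the ML categorizer (all names here are illustrative, not the actual API):

    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    // Simplified model of CategorizePackedValuesBlockHash: stage 1 turns the
    // message column into category IDs, stage 2 assigns group IDs to
    // (categoryId, otherKey) tuples, the way PackedValuesBlockHash hashes
    // packed rows.
    class TwoStageGroupingSketch {
        private final Map<String, Integer> categories = new HashMap<>();        // stage 1 state
        private final Map<List<Object>, Integer> packedGroups = new HashMap<>(); // stage 2 state

        // Toy placeholder for the real tokenizing categorizer: first word = category.
        private int categorize(String message) {
            String key = message.split(" ", 2)[0];
            return categories.computeIfAbsent(key, k -> categories.size() + 1); // ordinal 0 reserved for null
        }

        int groupId(String message, Object otherKey) {
            int categoryId = message == null ? 0 : categorize(message);
            List<Object> packed = List.of(categoryId, otherKey);
            return packedGroups.computeIfAbsent(packed, k -> packedGroups.size());
        }

        public static void main(String[] args) {
            TwoStageGroupingSketch hash = new TwoStageGroupingSketch();
            System.out.println(hash.groupId("Connected to 10.1.0.1", "host-1")); // 0
            System.out.println(hash.groupId("Connected to 10.1.0.2", "host-1")); // 0: same category and host
            System.out.println(hash.groupId("Connected to 10.1.0.2", "host-2")); // 1: same category, new host
        }
    }

Keeping the stages separate is what lets the categorizer state be serialized and merged across nodes independently of the packed-values hash, which the intermediate-state handling in the files below relies on.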

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java

Lines changed: 42 additions & 37 deletions

@@ -44,7 +44,7 @@
 import java.util.Objects;

 /**
- * Base BlockHash implementation for {@code Categorize} grouping function.
+ * BlockHash implementation for {@code Categorize} grouping function.
  */
 public class CategorizeBlockHash extends BlockHash {

@@ -53,11 +53,9 @@ public class CategorizeBlockHash extends BlockHash {
     );
     private static final int NULL_ORD = 0;

-    // TODO: this should probably also take an emitBatchSize
     private final int channel;
     private final AggregatorMode aggregatorMode;
     private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
-
     private final CategorizeEvaluator evaluator;

     /**

@@ -95,12 +93,14 @@ public class CategorizeBlockHash extends BlockHash {
         }
     }

+    boolean seenNull() {
+        return seenNull;
+    }
+
     @Override
     public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
-        if (aggregatorMode.isInputPartial() == false) {
-            addInitial(page, addInput);
-        } else {
-            addIntermediate(page, addInput);
+        try (IntBlock block = add(page)) {
+            addInput.add(0, block);
         }
     }

@@ -129,50 +129,38 @@ public void close() {
         Releasables.close(evaluator, categorizer);
     }

+    private IntBlock add(Page page) {
+        return aggregatorMode.isInputPartial() == false ? addInitial(page) : addIntermediate(page);
+    }
+
     /**
      * Adds initial (raw) input to the state.
      */
-    private void addInitial(Page page, GroupingAggregatorFunction.AddInput addInput) {
-        try (IntBlock result = (IntBlock) evaluator.eval(page.getBlock(channel))) {
-            addInput.add(0, result);
-        }
+    IntBlock addInitial(Page page) {
+        return (IntBlock) evaluator.eval(page.getBlock(channel));
     }

     /**
      * Adds intermediate state to the state.
      */
-    private void addIntermediate(Page page, GroupingAggregatorFunction.AddInput addInput) {
+    private IntBlock addIntermediate(Page page) {
         if (page.getPositionCount() == 0) {
-            return;
+            return null;
         }
         BytesRefBlock categorizerState = page.getBlock(channel);
         if (categorizerState.areAllValuesNull()) {
             seenNull = true;
-            try (var newIds = blockFactory.newConstantIntVector(NULL_ORD, 1)) {
-                addInput.add(0, newIds);
-            }
-            return;
-        }
-
-        Map<Integer, Integer> idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef()));
-        try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) {
-            int fromId = idMap.containsKey(0) ? 0 : 1;
-            int toId = fromId + idMap.size();
-            for (int i = fromId; i < toId; i++) {
-                newIdsBuilder.appendInt(idMap.get(i));
-            }
-            try (IntBlock newIds = newIdsBuilder.build()) {
-                addInput.add(0, newIds);
-            }
+            return blockFactory.newConstantIntBlockWith(NULL_ORD, 1);
         }
+        return recategorize(categorizerState.getBytesRef(0, new BytesRef()), null).asBlock();
     }

     /**
-     * Read intermediate state from a block.
-     *
-     * @return a map from the old category id to the new one. The old ids go from 0 to {@code size - 1}.
+     * Reads the intermediate state from a block and recategorizes the provided IDs.
+     * If no IDs are provided, the IDs are the IDs in the categorizer's state in order.
+     * (So 0...N-1 or 1...N, depending on whether null is present.)
      */
-    private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
+    IntVector recategorize(BytesRef bytes, IntVector ids) {
         Map<Integer, Integer> idMap = new HashMap<>();
         try (StreamInput in = new BytesArray(bytes).streamInput()) {
             if (in.readBoolean()) {

@@ -185,10 +173,22 @@ private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
                 // +1 because the 0 ordinal is reserved for null
                 idMap.put(oldCategoryId + 1, newCategoryId + 1);
             }
-            return idMap;
         } catch (IOException e) {
             throw new RuntimeException(e);
         }
+        try (IntVector.Builder newIdsBuilder = blockFactory.newIntVectorBuilder(idMap.size())) {
+            if (ids == null) {
+                int idOffset = idMap.containsKey(0) ? 0 : 1;
+                for (int i = 0; i < idMap.size(); i++) {
+                    newIdsBuilder.appendInt(idMap.get(i + idOffset));
+                }
+            } else {
+                for (int i = 0; i < ids.getPositionCount(); i++) {
+                    newIdsBuilder.appendInt(idMap.get(ids.getInt(i)));
+                }
+            }
+            return newIdsBuilder.build();
+        }
     }

     /**

@@ -198,15 +198,20 @@ private Block buildIntermediateBlock() {
         if (categorizer.getCategoryCount() == 0) {
             return blockFactory.newConstantNullBlock(seenNull ? 1 : 0);
         }
+        int positionCount = categorizer.getCategoryCount() + (seenNull ? 1 : 0);
+        // We're returning a block with N positions just because the Page must have all blocks with the same position count!
+        return blockFactory.newConstantBytesRefBlockWith(serializeCategorizer(), positionCount);
+    }
+
+    BytesRef serializeCategorizer() {
+        // TODO: This BytesStreamOutput is not accounted for by the circuit breaker. Fix that!
         try (BytesStreamOutput out = new BytesStreamOutput()) {
             out.writeBoolean(seenNull);
             out.writeVInt(categorizer.getCategoryCount());
             for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
                 category.writeTo(out);
             }
-            // We're returning a block with N positions just because the Page must have all blocks with the same position count!
-            int positionCount = categorizer.getCategoryCount() + (seenNull ? 1 : 0);
-            return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), positionCount);
+            return out.bytes().toBytesRef();
         } catch (IOException e) {
             throw new RuntimeException(e);
         }
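The recategorize method introduced above merges a serialized categorizer into this node's categorizer and pushes category IDs through the resulting old-to-new map. A minimal sketch of just the remapping step, in plain Java over hypothetical data (the real code operates on an IntVector and builds idMap by replaying SerializableTokenListCategory entries into the live categorizer):

    import java.util.Arrays;
    import java.util.Map;

    // Sketch of the recategorize() remap: idMap translates category IDs from a
    // data node's intermediate state into this node's IDs. Ordinal 0 stays
    // reserved for null on both sides.
    class RecategorizeSketch {
        static int[] recategorize(Map<Integer, Integer> idMap, int[] ids) {
            int[] newIds = new int[ids == null ? idMap.size() : ids.length];
            if (ids == null) {
                // No explicit IDs: take the state's IDs in order, 0...N-1 or 1...N
                // depending on whether null (ordinal 0) was seen.
                int idOffset = idMap.containsKey(0) ? 0 : 1;
                for (int i = 0; i < idMap.size(); i++) {
                    newIds[i] = idMap.get(i + idOffset);
                }
            } else {
                for (int i = 0; i < ids.length; i++) {
                    newIds[i] = idMap.get(ids[i]);
                }
            }
            return newIds;
        }

        public static void main(String[] args) {
            // Hypothetical merge result: sender's categories 1, 2 map to local 3, 1.
            Map<Integer, Integer> idMap = Map.of(1, 3, 2, 1);
            System.out.println(Arrays.toString(recategorize(idMap, null)));              // [3, 1]
            System.out.println(Arrays.toString(recategorize(idMap, new int[] {2, 2, 1}))); // [1, 1, 3]
        }
    }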

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java

Lines changed: 170 additions & 0 deletions

@@ -0,0 +1,170 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation.blockhash;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
+import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.core.ReleasableIterator;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.index.analysis.AnalysisRegistry;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * BlockHash implementation for {@code Categorize} grouping function as first
+ * grouping expression, followed by one or more other grouping expressions.
+ * <p>
+ * For the first grouping (the {@code Categorize} grouping function), a
+ * {@code CategorizeBlockHash} is used, which outputs integers (category IDs).
+ * Next, a {@code PackedValuesBlockHash} is used on the category IDs and the
+ * other groupings (which are not {@code Categorize}s).
+ */
+public class CategorizePackedValuesBlockHash extends BlockHash {
+
+    private final List<GroupSpec> specs;
+    private final AggregatorMode aggregatorMode;
+    private final Block[] blocks;
+    private final CategorizeBlockHash categorizeBlockHash;
+    private final PackedValuesBlockHash packedValuesBlockHash;
+
+    CategorizePackedValuesBlockHash(
+        List<GroupSpec> specs,
+        BlockFactory blockFactory,
+        AggregatorMode aggregatorMode,
+        AnalysisRegistry analysisRegistry,
+        int emitBatchSize
+    ) {
+        super(blockFactory);
+        this.specs = specs;
+        this.aggregatorMode = aggregatorMode;
+        blocks = new Block[specs.size()];
+
+        List<GroupSpec> delegateSpecs = new ArrayList<>();
+        delegateSpecs.add(new GroupSpec(0, ElementType.INT));
+        for (int i = 1; i < specs.size(); i++) {
+            delegateSpecs.add(new GroupSpec(i, specs.get(i).elementType()));
+        }
+
+        boolean success = false;
+        try {
+            categorizeBlockHash = new CategorizeBlockHash(blockFactory, specs.get(0).channel(), aggregatorMode, analysisRegistry);
+            packedValuesBlockHash = new PackedValuesBlockHash(delegateSpecs, blockFactory, emitBatchSize);
+            success = true;
+        } finally {
+            if (success == false) {
+                close();
+            }
+        }
+    }
+
+    @Override
+    public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
+        try (IntBlock categories = getCategories(page)) {
+            blocks[0] = categories;
+            for (int i = 1; i < specs.size(); i++) {
+                blocks[i] = page.getBlock(specs.get(i).channel());
+            }
+            packedValuesBlockHash.add(new Page(blocks), addInput);
+        }
+    }
+
+    private IntBlock getCategories(Page page) {
+        if (aggregatorMode.isInputPartial() == false) {
+            return categorizeBlockHash.addInitial(page);
+        } else {
+            BytesRefBlock stateBlock = page.getBlock(0);
+            BytesRef stateBytes = stateBlock.getBytesRef(0, new BytesRef());
+            try (StreamInput in = new BytesArray(stateBytes).streamInput()) {
+                BytesRef categorizerState = in.readBytesRef();
+                try (IntVector ids = IntVector.readFrom(blockFactory, in)) {
+                    return categorizeBlockHash.recategorize(categorizerState, ids).asBlock();
+                }
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
+    }
+
+    @Override
+    public Block[] getKeys() {
+        Block[] keys = packedValuesBlockHash.getKeys();
+        if (aggregatorMode.isOutputPartial() == false) {
+            // For final output, the keys are the category regexes.
+            try (
+                BytesRefBlock regexes = (BytesRefBlock) categorizeBlockHash.getKeys()[0];
+                BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(keys[0].getPositionCount())
+            ) {
+                IntVector idsVector = (IntVector) keys[0].asVector();
+                int idsOffset = categorizeBlockHash.seenNull() ? 0 : -1;
+                BytesRef scratch = new BytesRef();
+                for (int i = 0; i < idsVector.getPositionCount(); i++) {
+                    int id = idsVector.getInt(i);
+                    if (id == 0) {
+                        builder.appendNull();
+                    } else {
+                        builder.appendBytesRef(regexes.getBytesRef(id + idsOffset, scratch));
+                    }
+                }
+                keys[0].close();
+                keys[0] = builder.build();
+            }
+        } else {
+            // For intermediate output, the keys are the delegate PackedValuesBlockHash's
+            // keys, with the category IDs replaced by the categorizer's internal state
+            // together with the list of category IDs.
+            BytesRef state;
+            // TODO: This BytesStreamOutput is not accounted for by the circuit breaker. Fix that!
+            try (BytesStreamOutput out = new BytesStreamOutput()) {
+                out.writeBytesRef(categorizeBlockHash.serializeCategorizer());
+                ((IntVector) keys[0].asVector()).writeTo(out);
+                state = out.bytes().toBytesRef();
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+            keys[0].close();
+            keys[0] = blockFactory.newConstantBytesRefBlockWith(state, keys[0].getPositionCount());
+        }
+        return keys;
+    }
+
+    @Override
+    public IntVector nonEmpty() {
+        return packedValuesBlockHash.nonEmpty();
+    }
+
+    @Override
+    public BitArray seenGroupIds(BigArrays bigArrays) {
+        return packedValuesBlockHash.seenGroupIds(bigArrays);
+    }
+
+    @Override
+    public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public void close() {
+        Releasables.close(categorizeBlockHash, packedValuesBlockHash);
+    }
+}
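The intermediate keys built in getKeys() above concatenate two payloads, the serialized categorizer (writeBytesRef) followed by the category-ID vector (IntVector.writeTo), and getCategories() splits them apart again on the receiving side. A rough sketch of that framing as a round-trip over plain java.io streams; the actual code uses Elasticsearch's BytesStreamOutput/StreamInput with variable-length encodings, not the fixed 4-byte ints used here:

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.IOException;

    // Sketch of the intermediate-state framing: length-prefixed categorizer
    // bytes, then the category IDs, mirroring writeBytesRef(...) followed by
    // IntVector.writeTo(...).
    class IntermediateStateSketch {
        static byte[] write(byte[] categorizerState, int[] ids) throws IOException {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (DataOutputStream out = new DataOutputStream(bytes)) {
                out.writeInt(categorizerState.length);   // stand-in for writeBytesRef
                out.write(categorizerState);
                out.writeInt(ids.length);                // stand-in for IntVector.writeTo
                for (int id : ids) {
                    out.writeInt(id);
                }
            }
            return bytes.toByteArray();
        }

        static int[] readIds(byte[] state) throws IOException {
            try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(state))) {
                byte[] categorizerState = new byte[in.readInt()];
                in.readFully(categorizerState);          // would feed recategorize(...)
                int[] ids = new int[in.readInt()];
                for (int i = 0; i < ids.length; i++) {
                    ids[i] = in.readInt();
                }
                return ids;
            }
        }

        public static void main(String[] args) throws IOException {
            byte[] state = write(new byte[] {42}, new int[] {1, 2, 3});
            System.out.println(java.util.Arrays.toString(readIds(state))); // [1, 2, 3]
        }
    }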

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java

Lines changed: 7 additions & 1 deletion

@@ -51,7 +51,13 @@ public Operator get(DriverContext driverContext) {
         if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) {
             return new HashAggregationOperator(
                 aggregators,
-                () -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory(), analysisRegistry),
+                () -> BlockHash.buildCategorizeBlockHash(
+                    groups,
+                    aggregatorMode,
+                    driverContext.blockFactory(),
+                    analysisRegistry,
+                    maxPageSize
+                ),
                 driverContext
             );
         }

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java

Lines changed: 0 additions & 3 deletions

@@ -130,9 +130,6 @@ public void close() {
         } finally {
             page.releaseBlocks();
         }
-
-        // TODO: randomize values? May give wrong results
-        // TODO: assert the categorizer state after adding pages.
     }

     public void testCategorizeRawMultivalue() {
