
Commit defe807

ES|QL categorize with multiple groupings.

1 parent 7ffac3b

8 files changed, +289 -30 lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java

Lines changed: 7 additions & 5 deletions

@@ -174,13 +174,15 @@ public static BlockHash buildCategorizeBlockHash(
         List<GroupSpec> groups,
         AggregatorMode aggregatorMode,
         BlockFactory blockFactory,
-        AnalysisRegistry analysisRegistry
+        AnalysisRegistry analysisRegistry,
+        int emitBatchSize
     ) {
-        if (groups.size() != 1) {
-            throw new IllegalArgumentException("only a single CATEGORIZE group can used");
+        if (groups.size() == 1) {
+            return new CategorizeBlockHash(blockFactory, groups.get(0).channel, aggregatorMode, analysisRegistry);
+        } else {
+            assert groups.get(0).isCategorize();
+            return new CategorizePackedValuesBlockHash(groups, blockFactory, aggregatorMode, analysisRegistry, emitBatchSize);
         }
-
-        return new CategorizeBlockHash(blockFactory, groups.get(0).channel, aggregatorMode, analysisRegistry);
     }

     /**
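
The factory now dispatches on the number of group specs instead of rejecting anything but one: a lone CATEGORIZE keeps the existing single-column hash, and extra groupings route to the new combined hash. A minimal standalone sketch of that shape, with hypothetical stand-ins for the real classes (whose constructors also take the block factory, aggregator mode, analysis registry, and emit batch size shown above):

import java.util.List;

// Hypothetical stand-ins for the classes in org.elasticsearch.compute.aggregation.blockhash.
interface Hash {}
record SingleCategorizeHash(int channel) implements Hash {}
record CategorizeThenPackHash(List<Spec> specs) implements Hash {}
record Spec(int channel, boolean isCategorize) {}

class DispatchSketch {
    static Hash build(List<Spec> groups) {
        if (groups.size() == 1) {
            return new SingleCategorizeHash(groups.get(0).channel());
        }
        // With multiple groupings, CATEGORIZE is expected to be the first spec
        // (the real code asserts this).
        assert groups.get(0).isCategorize();
        return new CategorizeThenPackHash(groups);
    }
}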

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java

Lines changed: 25 additions & 22 deletions

@@ -44,7 +44,7 @@
 import java.util.Objects;

 /**
- * Base BlockHash implementation for {@code Categorize} grouping function.
+ * BlockHash implementation for {@code Categorize} grouping function.
  */
 public class CategorizeBlockHash extends BlockHash {

@@ -95,12 +95,14 @@ public class CategorizeBlockHash extends BlockHash {
         }
     }

+    boolean seenNull() {
+        return seenNull;
+    }
+
     @Override
     public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
-        if (aggregatorMode.isInputPartial() == false) {
-            addInitial(page, addInput);
-        } else {
-            addIntermediate(page, addInput);
+        try (IntBlock block = add(page)) {
+            addInput.add(0, block);
         }
     }

@@ -129,29 +131,28 @@ public void close() {
         Releasables.close(evaluator, categorizer);
     }

+    private IntBlock add(Page page) {
+        return aggregatorMode.isInputPartial() == false ? addInitial(page) : addIntermediate(page);
+    }
+
     /**
      * Adds initial (raw) input to the state.
      */
-    private void addInitial(Page page, GroupingAggregatorFunction.AddInput addInput) {
-        try (IntBlock result = (IntBlock) evaluator.eval(page.getBlock(channel))) {
-            addInput.add(0, result);
-        }
+    IntBlock addInitial(Page page) {
+        return (IntBlock) evaluator.eval(page.getBlock(channel));
     }

     /**
      * Adds intermediate state to the state.
      */
-    private void addIntermediate(Page page, GroupingAggregatorFunction.AddInput addInput) {
+    private IntBlock addIntermediate(Page page) {
         if (page.getPositionCount() == 0) {
-            return;
+            return null;
         }
         BytesRefBlock categorizerState = page.getBlock(channel);
         if (categorizerState.areAllValuesNull()) {
             seenNull = true;
-            try (var newIds = blockFactory.newConstantIntVector(NULL_ORD, 1)) {
-                addInput.add(0, newIds);
-            }
-            return;
+            return blockFactory.newConstantIntBlockWith(NULL_ORD, 1);
         }

         Map<Integer, Integer> idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef()));
@@ -161,9 +162,7 @@ private void addIntermediate(Page page, GroupingAggregatorFunction.AddInput addI
         for (int i = fromId; i < toId; i++) {
             newIdsBuilder.appendInt(idMap.get(i));
         }
-        try (IntBlock newIds = newIdsBuilder.build()) {
-            addInput.add(0, newIds);
-        }
+        return newIdsBuilder.build();
     }
 }

@@ -172,7 +171,7 @@ private void addIntermediate(Page page, GroupingAggregatorFunction.AddInput addI
      *
      * @return a map from the old category id to the new one. The old ids go from 0 to {@code size - 1}.
      */
-    private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
+    Map<Integer, Integer> readIntermediate(BytesRef bytes) {
         Map<Integer, Integer> idMap = new HashMap<>();
         try (StreamInput in = new BytesArray(bytes).streamInput()) {
             if (in.readBoolean()) {
@@ -198,15 +197,19 @@ private Block buildIntermediateBlock() {
         if (categorizer.getCategoryCount() == 0) {
             return blockFactory.newConstantNullBlock(seenNull ? 1 : 0);
         }
+        int positionCount = categorizer.getCategoryCount() + (seenNull ? 1 : 0);
+        // We're returning a block with N positions just because the Page must have all blocks with the same position count!
+        return blockFactory.newConstantBytesRefBlockWith(serializeCategorizer(), positionCount);
+    }
+
+    BytesRef serializeCategorizer() {
         try (BytesStreamOutput out = new BytesStreamOutput()) {
             out.writeBoolean(seenNull);
             out.writeVInt(categorizer.getCategoryCount());
             for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
                 category.writeTo(out);
             }
-            // We're returning a block with N positions just because the Page must have all blocks with the same position count!
-            int positionCount = categorizer.getCategoryCount() + (seenNull ? 1 : 0);
-            return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), positionCount);
+            return out.bytes().toBytesRef();
         } catch (IOException e) {
             throw new RuntimeException(e);
         }
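
For orientation: the intermediate state produced by serializeCategorizer() is just the seenNull flag, a category count, and the categories in id order, and readIntermediate() replays those categories into the local categorizer while recording how each incoming id maps to a local one. A rough standalone analogue of that framing, assuming plain java.io streams in place of Elasticsearch's BytesStreamOutput/StreamInput and strings in place of SerializableTokenListCategory:

import java.io.*;
import java.util.*;

class CategorizerStateSketch {
    // Write: seenNull flag, category count, then each category in id order.
    static byte[] serialize(boolean seenNull, List<String> categories) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeBoolean(seenNull);
            out.writeInt(categories.size()); // the real code uses a variable-length int
            for (String category : categories) {
                out.writeUTF(category);
            }
        }
        return bytes.toByteArray();
    }

    // Read the state back, merging into `local` and returning old-id -> new-id.
    static Map<Integer, Integer> readIntermediate(byte[] state, List<String> local) throws IOException {
        Map<Integer, Integer> idMap = new HashMap<>();
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(state))) {
            boolean seenNull = in.readBoolean(); // the real code reserves id 0 (NULL_ORD) when set
            int count = in.readInt();
            for (int oldId = 0; oldId < count; oldId++) {
                String category = in.readUTF();
                int newId = local.indexOf(category);
                if (newId < 0) {
                    newId = local.size();
                    local.add(category);
                }
                idMap.put(oldId, newId);
            }
        }
        return idMap;
    }
}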
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java

Lines changed: 156 additions & 0 deletions

@@ -0,0 +1,156 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation.blockhash;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
+import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.core.ReleasableIterator;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.index.analysis.AnalysisRegistry;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * BlockHash implementation for {@code Categorize} grouping function as first
+ * grouping expression, followed by one or more other grouping expressions.
+ */
+public class CategorizePackedValuesBlockHash extends BlockHash {
+
+    private final AggregatorMode aggregatorMode;
+    private final List<GroupSpec> specs;
+    private final CategorizeBlockHash categorizeBlockHash;
+    private final PackedValuesBlockHash packedValuesBlockHash;
+
+    CategorizePackedValuesBlockHash(
+        List<GroupSpec> specs,
+        BlockFactory blockFactory,
+        AggregatorMode aggregatorMode,
+        AnalysisRegistry analysisRegistry,
+        int emitBatchSize
+    ) {
+        super(blockFactory);
+        this.aggregatorMode = aggregatorMode;
+        this.specs = specs;
+        categorizeBlockHash = new CategorizeBlockHash(blockFactory, specs.get(0).channel(), aggregatorMode, analysisRegistry);
+
+        List<GroupSpec> newSpecs = new ArrayList<>(specs);
+        newSpecs.set(0, new GroupSpec(-1, ElementType.INT));
+        packedValuesBlockHash = new PackedValuesBlockHash(newSpecs, blockFactory, emitBatchSize);
+
+        // TODO: close stuff upon failure
+    }
+
+    @Override
+    public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
+        try (IntBlock categories = getCategories(page)) {
+            packedValuesBlockHash.add(page.appendBlock(categories), addInput);
+        }
+    }
+
+    private IntBlock getCategories(Page page) {
+        if (aggregatorMode.isInputPartial() == false) {
+            return categorizeBlockHash.addInitial(page);
+        } else {
+            BytesRefBlock stateBlock = page.getBlock(0);
+            BytesRef stateBytes = stateBlock.getBytesRef(0, new BytesRef());
+
+            try (StreamInput in = new BytesArray(stateBytes).streamInput()) {
+                BytesRef categorizerState = in.readBytesRef();
+                Map<Integer, Integer> idMap = categorizeBlockHash.readIntermediate(categorizerState);
+                int[] oldIds = in.readIntArray();
+                try (IntBlock.Builder newIds = blockFactory.newIntBlockBuilder(page.getPositionCount())) {
+                    for (int oldId : oldIds) {
+                        newIds.appendInt(idMap.get(oldId));
+                    }
+                    return newIds.build();
+                }
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
+    }
+
+    @Override
+    public Block[] getKeys() {
+        Block[] keys = packedValuesBlockHash.getKeys();
+        if (aggregatorMode.isOutputPartial() == false) {
+            try (
+                BytesRefBlock regexes = (BytesRefBlock) categorizeBlockHash.getKeys()[0];
+                BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(keys[0].getPositionCount())
+            ) {
+                IntVector idsVector = (IntVector) keys[0].asVector();
+                int idsOffset = categorizeBlockHash.seenNull() ? 0 : -1;
+                BytesRef scratch = new BytesRef();
+                for (int i = 0; i < idsVector.getPositionCount(); i++) {
+                    int id = idsVector.getInt(i);
+                    if (id == 0) {
+                        builder.appendNull();
+                    } else {
+                        builder.appendBytesRef(regexes.getBytesRef(id + idsOffset, scratch));
+                    }
+                }
+                keys[0].close();
+                keys[0] = builder.build();
+            }
+        } else {
+            BytesRef state;
+            try (BytesStreamOutput out = new BytesStreamOutput()) {
+                out.writeBytesRef(categorizeBlockHash.serializeCategorizer());
+                IntVector idsVector = (IntVector) keys[0].asVector();
+                int[] idsArray = new int[idsVector.getPositionCount()];
+                for (int i = 0; i < idsVector.getPositionCount(); i++) {
+                    idsArray[i] = idsVector.getInt(i);
+                }
+                out.writeIntArray(idsArray);
+                state = out.bytes().toBytesRef();
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+            keys[0].close();
+            keys[0] = blockFactory.newConstantBytesRefBlockWith(state, keys[0].getPositionCount());
+        }
+        return keys;
+    }
+
+    @Override
+    public IntVector nonEmpty() {
+        return packedValuesBlockHash.nonEmpty();
+    }
+
+    @Override
+    public BitArray seenGroupIds(BigArrays bigArrays) {
+        return packedValuesBlockHash.seenGroupIds(bigArrays);
+    }
+
+    @Override
+    public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public void close() {
+        Releasables.close(categorizeBlockHash, packedValuesBlockHash);
+    }
+}
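
Putting the pieces together: add() turns the message column into an IntBlock of category ids (raw input) or remaps ids arriving from another node (partial input), appends that block to the page, and lets PackedValuesBlockHash group on the ids together with the remaining columns; getKeys() then swaps the id column back for the category regexes in the final output. A standalone sketch of that last substitution, with plain arrays standing in for blocks:

class KeyRebuildSketch {
    // Mirrors the final-output branch of getKeys(): ordinal 0 is the null group,
    // and the regex index shifts by one when no null was ever seen.
    static String[] idsToRegexes(int[] ids, String[] regexes, boolean seenNull) {
        int idsOffset = seenNull ? 0 : -1;
        String[] keys = new String[ids.length];
        for (int i = 0; i < ids.length; i++) {
            keys[i] = ids[i] == 0 ? null : regexes[ids[i] + idsOffset];
        }
        return keys;
    }
}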

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/Page.java

Lines changed: 3 additions & 0 deletions

@@ -131,6 +131,9 @@ public <B extends Block> B getBlock(int blockIndex) {
         if (blocksReleased) {
             throw new IllegalStateException("can't read released page");
         }
+        if (blockIndex < 0) {
+            blockIndex += blocks.length;
+        }
         @SuppressWarnings("unchecked")
         B block = (B) blocks[blockIndex];
         if (block.isReleased()) {
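
getBlock() now accepts negative indices that count back from the end, so page.getBlock(-1) returns the last block; this is what lets CategorizePackedValuesBlockHash reference the freshly appended category-id block through GroupSpec(-1, ElementType.INT) without knowing how many blocks the page started with. A minimal sketch of the resolution rule:

class NegativeIndexSketch {
    // Mirrors the new bounds handling in Page.getBlock(): negative indices wrap.
    static int resolve(int blockIndex, int blockCount) {
        if (blockIndex < 0) {
            blockIndex += blockCount;
        }
        return blockIndex;
    }

    public static void main(String[] args) {
        System.out.println(resolve(-1, 4)); // prints 3: the last block of a 4-block page
    }
}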

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java

Lines changed: 7 additions & 1 deletion

@@ -51,7 +51,13 @@ public Operator get(DriverContext driverContext) {
         if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) {
             return new HashAggregationOperator(
                 aggregators,
-                () -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory(), analysisRegistry),
+                () -> BlockHash.buildCategorizeBlockHash(
+                    groups,
+                    aggregatorMode,
+                    driverContext.blockFactory(),
+                    analysisRegistry,
+                    maxPageSize
+                ),
                 driverContext
             );
         }

x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec

Lines changed: 81 additions & 0 deletions

@@ -592,3 +592,84 @@ COUNT():long | x:keyword
 3 | [.*?Connection.+?error.*?,.*?Connection.+?error.*?]
 1 | [.*?Disconnected.*?,.*?Disconnected.*?]
 ;
+
+multiple groupings with categorize and ip
+required_capability: categorize_multiple_groupings
+
+FROM sample_data
+| STATS count=COUNT() BY category=CATEGORIZE(message), client_ip
+| SORT category, client_ip
+;
+
+count:long | category:keyword | client_ip:ip
+1 | .*?Connected.+?to.*? | 172.21.2.113
+1 | .*?Connected.+?to.*? | 172.21.2.162
+1 | .*?Connected.+?to.*? | 172.21.3.15
+3 | .*?Connection.+?error.*? | 172.21.3.15
+1 | .*?Disconnected.*? | 172.21.0.5
+;
+
+multiple groupings with categorize and bucketed timestamp
+required_capability: categorize_multiple_groupings
+
+FROM sample_data
+| STATS count=COUNT() BY category=CATEGORIZE(message), timestamp=BUCKET(@timestamp, 1 HOUR)
+| SORT category, timestamp
+;
+
+count:long | category:keyword | timestamp:datetime
+2 | .*?Connected.+?to.*? | 2023-10-23T12:00:00.000Z
+1 | .*?Connected.+?to.*? | 2023-10-23T13:00:00.000Z
+3 | .*?Connection.+?error.*? | 2023-10-23T13:00:00.000Z
+1 | .*?Disconnected.*? | 2023-10-23T13:00:00.000Z
+;
+
+multiple groupings with categorize and nulls
+required_capability: categorize_multiple_groupings
+
+FROM employees
+| STATS SUM(languages) BY category=CATEGORIZE(job_positions), gender
+| SORT category DESC, gender ASC
+| LIMIT 5
+;
+
+SUM(languages):long | category:keyword | gender:keyword
+11 | null | F
+16 | null | M
+14 | .*?Tech.+?Lead.*? | F
+23 | .*?Tech.+?Lead.*? | M
+9 | .*?Tech.+?Lead.*? | null
+;
+
+multiple groupings with categorize and a field that's always null
+required_capability: categorize_multiple_groupings
+
+FROM sample_data
+| EVAL nullfield = null
+| STATS count=COUNT() BY category=CATEGORIZE(nullfield), client_ip
+| SORT client_ip
+;
+
+count:long | category:keyword | client_ip:ip
+1 | null | 172.21.0.5
+1 | null | 172.21.2.113
+1 | null | 172.21.2.162
+4 | null | 172.21.3.15
+;
+
+
+multiple groupings with categorize and the same text field
+required_capability: categorize_multiple_groupings
+
+FROM sample_data
+| STATS count=COUNT() BY category=CATEGORIZE(message), message
+| SORT message
+;
+
+count:long | category:keyword | message:keyword
+1 | .*?Connected.+?to.*? | Connected to 10.1.0.1
+1 | .*?Connected.+?to.*? | Connected to 10.1.0.2
+1 | .*?Connected.+?to.*? | Connected to 10.1.0.3
+3 | .*?Connection.+?error.*? | Connection error
+1 | .*?Disconnected.*? | Disconnected
+;
