
Commit 6bdc63a

ES|QL categorize options
1 parent 88d765d

File tree

24 files changed: +359 additions, -54 deletions


docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/categorize.md

Lines changed: 13 additions & 0 deletions
Generated file; diff not rendered by default.

docs/reference/query-languages/esql/_snippets/functions/layout/categorize.md

Lines changed: 3 additions & 0 deletions
Generated file; diff not rendered by default.

docs/reference/query-languages/esql/_snippets/functions/parameters/categorize.md

Lines changed: 3 additions & 0 deletions
Generated file; diff not rendered by default.

docs/reference/query-languages/esql/_snippets/functions/types/categorize.md

Lines changed: 4 additions & 4 deletions
Generated file; diff not rendered by default.

docs/reference/query-languages/esql/images/functions/categorize.svg

Lines changed: 1 addition & 1 deletion

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
@@ -338,6 +338,7 @@ static TransportVersion def(int id) {
     public static final TransportVersion ESQL_FIXED_INDEX_LIKE = def(9_119_0_00);
     public static final TransportVersion LOOKUP_JOIN_CCS = def(9_120_0_00);
     public static final TransportVersion NODE_USAGE_STATS_FOR_THREAD_POOLS_IN_CLUSTER_INFO = def(9_121_0_00);
+    public static final TransportVersion ESQL_CATEGORIZE_OPTIONS = def(9_122_0_00);
 
     /*
      * STOP! READ THIS FIRST! No, really,
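
The new ESQL_CATEGORIZE_OPTIONS constant is the usual prerequisite for shipping the categorize options over the wire in mixed-version clusters. A minimal sketch of that gating pattern, assuming the options are serialized as part of the function's writeTo/readFrom; the record and its fields below are illustrative, not the actual serialization code in this commit:

import java.io.IOException;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;

// Hypothetical holder for the CATEGORIZE options; only the version-gating pattern is the point here.
record CategorizeOptionsSketch(String analyzer, String outputFormat, int similarityThreshold) {

    void writeTo(StreamOutput out) throws IOException {
        // Older nodes don't know these fields exist, so only serialize them when the
        // stream is on or after the new transport version.
        if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CATEGORIZE_OPTIONS)) {
            out.writeOptionalString(analyzer);
            out.writeString(outputFormat);
            out.writeVInt(similarityThreshold);
        }
    }

    static CategorizeOptionsSketch readFrom(StreamInput in) throws IOException {
        if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_CATEGORIZE_OPTIONS)) {
            return new CategorizeOptionsSketch(in.readOptionalString(), in.readString(), in.readVInt());
        }
        // Pre-upgrade peer: fall back to defaults equivalent to the old behaviour (assumed values).
        return new CategorizeOptionsSketch(null, "REGEX", 70);
    }
}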

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java

Lines changed: 23 additions & 7 deletions
@@ -128,16 +128,26 @@ public abstract class BlockHash implements Releasable, SeenGroupIds {
     public record TopNDef(int order, boolean asc, boolean nullsFirst, int limit) {}
 
     /**
-     * @param isCategorize Whether this group is a CATEGORIZE() or not.
-     *                     May be changed in the future when more stateful grouping functions are added.
+     * Configuration for a BlockHash group spec that is doing text categorization.
      */
-    public record GroupSpec(int channel, ElementType elementType, boolean isCategorize, @Nullable TopNDef topNDef) {
+    public record CategorizeDef(String analyzer, OutputFormat outputFormat, int similarityThreshold) {
+        public enum OutputFormat {
+            REGEX,
+            TOKENS
+        }
+    }
+
+    public record GroupSpec(int channel, ElementType elementType, @Nullable CategorizeDef categorizeDef, @Nullable TopNDef topNDef) {
         public GroupSpec(int channel, ElementType elementType) {
-            this(channel, elementType, false, null);
+            this(channel, elementType, null, null);
+        }
+
+        public GroupSpec(int channel, ElementType elementType, CategorizeDef categorizeDef) {
+            this(channel, elementType, categorizeDef, null);
         }
 
-        public GroupSpec(int channel, ElementType elementType, boolean isCategorize) {
-            this(channel, elementType, isCategorize, null);
+        public boolean isCategorize() {
+            return categorizeDef != null;
         }
     }
 
@@ -207,7 +217,13 @@ public static BlockHash buildCategorizeBlockHash(
         int emitBatchSize
     ) {
         if (groups.size() == 1) {
-            return new CategorizeBlockHash(blockFactory, groups.get(0).channel, aggregatorMode, analysisRegistry);
+            return new CategorizeBlockHash(
+                blockFactory,
+                groups.get(0).channel,
+                aggregatorMode,
+                groups.get(0).categorizeDef,
+                analysisRegistry
+            );
         } else {
             assert groups.get(0).isCategorize();
             assert groups.subList(1, groups.size()).stream().noneMatch(GroupSpec::isCategorize);
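
A minimal sketch of how the new records compose, using the constructors shown above; the concrete values are illustrative defaults, not taken from this commit:

import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash.CategorizeDef;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash.CategorizeDef.OutputFormat;
import org.elasticsearch.compute.data.ElementType;

class GroupSpecSketch {
    // CATEGORIZE(message): null analyzer -> the standard ES|QL categorization analyzer,
    // REGEX keys, 70% similarity threshold.
    static final BlockHash.GroupSpec CATEGORIZE_SPEC =
        new BlockHash.GroupSpec(0, ElementType.BYTES_REF, new CategorizeDef(null, OutputFormat.REGEX, 70));

    // A plain grouping key on another channel; the two-arg constructor leaves categorizeDef null.
    static final BlockHash.GroupSpec PLAIN_SPEC = new BlockHash.GroupSpec(1, ElementType.LONG);

    static void demo() {
        assert CATEGORIZE_SPEC.isCategorize();       // categorizeDef != null
        assert PLAIN_SPEC.isCategorize() == false;   // no CategorizeDef attached
    }
}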

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java

Lines changed: 28 additions & 11 deletions
@@ -18,7 +18,6 @@
 import org.elasticsearch.common.util.BytesRefHash;
 import org.elasticsearch.compute.aggregation.AggregatorMode;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
-import org.elasticsearch.compute.aggregation.SeenGroupIds;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.BytesRefBlock;

@@ -47,12 +46,13 @@
  */
 public class CategorizeBlockHash extends BlockHash {
 
-    private static final CategorizationAnalyzerConfig ANALYZER_CONFIG = CategorizationAnalyzerConfig
+    private static final CategorizationAnalyzerConfig DEFAULT_ANALYZER_CONFIG = CategorizationAnalyzerConfig
         .buildStandardEsqlCategorizationAnalyzer();
     private static final int NULL_ORD = 0;
 
     private final int channel;
     private final AggregatorMode aggregatorMode;
+    private final CategorizeDef categorizeDef;
     private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
     private final CategorizeEvaluator evaluator;
 
@@ -64,28 +64,38 @@ public class CategorizeBlockHash extends BlockHash {
      */
     private boolean seenNull = false;
 
-    CategorizeBlockHash(BlockFactory blockFactory, int channel, AggregatorMode aggregatorMode, AnalysisRegistry analysisRegistry) {
+    CategorizeBlockHash(
+        BlockFactory blockFactory,
+        int channel,
+        AggregatorMode aggregatorMode,
+        CategorizeDef categorizeDef,
+        AnalysisRegistry analysisRegistry
+    ) {
         super(blockFactory);
 
         this.channel = channel;
         this.aggregatorMode = aggregatorMode;
+        this.categorizeDef = categorizeDef;
 
         this.categorizer = new TokenListCategorizer.CloseableTokenListCategorizer(
             new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())),
             CategorizationPartOfSpeechDictionary.getInstance(),
-            0.70f
+            categorizeDef.similarityThreshold() / 100.0f
         );
 
         if (aggregatorMode.isInputPartial() == false) {
-            CategorizationAnalyzer analyzer;
+            CategorizationAnalyzer categorizationAnalyzer;
             try {
                 Objects.requireNonNull(analysisRegistry);
-                analyzer = new CategorizationAnalyzer(analysisRegistry, ANALYZER_CONFIG);
-            } catch (Exception e) {
+                CategorizationAnalyzerConfig config = categorizeDef.analyzer() == null
+                    ? DEFAULT_ANALYZER_CONFIG
+                    : new CategorizationAnalyzerConfig.Builder().setAnalyzer(categorizeDef.analyzer()).build();
+                categorizationAnalyzer = new CategorizationAnalyzer(analysisRegistry, config);
+            } catch (IOException e) {
                 categorizer.close();
                 throw new RuntimeException(e);
            }
-            this.evaluator = new CategorizeEvaluator(analyzer);
+            this.evaluator = new CategorizeEvaluator(categorizationAnalyzer);
         } else {
             this.evaluator = null;
         }
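
A small sketch of what the constructor derives from a CategorizeDef: a null analyzer falls back to DEFAULT_ANALYZER_CONFIG, any other value is resolved by name through the AnalysisRegistry, and the integer threshold is a percentage scaled into the 0-1 range the token list categorizer expects. The two example values below are illustrative, not part of this commit:

import org.elasticsearch.compute.aggregation.blockhash.BlockHash.CategorizeDef;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash.CategorizeDef.OutputFormat;

class CategorizeDefSketch {
    // Equivalent to the old hard-coded behaviour: default analyzer, regex keys, 0.70f threshold.
    static final CategorizeDef DEFAULTS = new CategorizeDef(null, OutputFormat.REGEX, 70);

    // A named analyzer resolved from the AnalysisRegistry, token keys, and a stricter
    // threshold after the similarityThreshold() / 100.0f conversion above.
    static final CategorizeDef CUSTOM = new CategorizeDef("standard", OutputFormat.TOKENS, 90);

    static float asFraction(CategorizeDef def) {
        return def.similarityThreshold() / 100.0f;   // 70 -> 0.70f, 90 -> 0.90f
    }
}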
@@ -114,7 +124,7 @@ public IntVector nonEmpty() {
 
     @Override
     public BitArray seenGroupIds(BigArrays bigArrays) {
-        return new SeenGroupIds.Range(seenNull ? 0 : 1, Math.toIntExact(categorizer.getCategoryCount() + 1)).seenGroupIds(bigArrays);
+        return new Range(seenNull ? 0 : 1, Math.toIntExact(categorizer.getCategoryCount() + 1)).seenGroupIds(bigArrays);
     }
 
     @Override

@@ -222,7 +232,7 @@ private Block buildFinalBlock() {
         try (BytesRefBlock.Builder result = blockFactory.newBytesRefBlockBuilder(categorizer.getCategoryCount())) {
             result.appendNull();
             for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
-                scratch.copyChars(category.getRegex());
+                scratch.copyChars(getKeyString(category));
                 result.appendBytesRef(scratch.get());
                 scratch.clear();
             }

@@ -232,14 +242,21 @@ private Block buildFinalBlock() {
 
         try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) {
             for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
-                scratch.copyChars(category.getRegex());
+                scratch.copyChars(getKeyString(category));
                 result.appendBytesRef(scratch.get());
                 scratch.clear();
             }
             return result.build().asBlock();
         }
     }
 
+    private String getKeyString(SerializableTokenListCategory category) {
+        return switch (categorizeDef.outputFormat()) {
+            case REGEX -> category.getRegex();
+            case TOKENS -> category.getKeyTokensString();
+        };
+    }
+
     /**
      * Similar implementation to an Evaluator.
      */

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java

Lines changed: 7 additions & 1 deletion
@@ -68,7 +68,13 @@ public class CategorizePackedValuesBlockHash extends BlockHash {
 
         boolean success = false;
         try {
-            categorizeBlockHash = new CategorizeBlockHash(blockFactory, specs.get(0).channel(), aggregatorMode, analysisRegistry);
+            categorizeBlockHash = new CategorizeBlockHash(
+                blockFactory,
+                specs.get(0).channel(),
+                aggregatorMode,
+                specs.get(0).categorizeDef(),
+                analysisRegistry
+            );
             packedValuesBlockHash = new PackedValuesBlockHash(delegateSpecs, blockFactory, emitBatchSize);
             success = true;
         } finally {
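
In the multi-key case only the first spec carries a CategorizeDef; the remaining keys are handled by the delegate PackedValuesBlockHash, matching the assertions in buildCategorizeBlockHash above. A sketch of that spec layout, with illustrative values:

import java.util.List;

import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash.CategorizeDef;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash.CategorizeDef.OutputFormat;
import org.elasticsearch.compute.data.ElementType;

class MultiKeyCategorizeSketch {
    static List<BlockHash.GroupSpec> specs() {
        CategorizeDef def = new CategorizeDef(null, OutputFormat.REGEX, 70);
        return List.of(
            // channel 0: CATEGORIZE(message) -> routed to the wrapped CategorizeBlockHash
            new BlockHash.GroupSpec(0, ElementType.BYTES_REF, def),
            // channel 1: an ordinary grouping key -> handled by the delegate PackedValuesBlockHash
            new BlockHash.GroupSpec(1, ElementType.LONG)
        );
    }
}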

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java

Lines changed: 12 additions & 6 deletions
@@ -63,6 +63,12 @@
 
 public class CategorizeBlockHashTests extends BlockHashTestCase {
 
+    private static final BlockHash.CategorizeDef CATEGORIZE_DEF = new BlockHash.CategorizeDef(
+        null,
+        BlockHash.CategorizeDef.OutputFormat.REGEX,
+        70
+    );
+
     private AnalysisRegistry analysisRegistry;
 
     @Before

@@ -98,7 +104,7 @@ public void testCategorizeRaw() {
             page = new Page(builder.build());
         }
 
-        try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) {
+        try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, CATEGORIZE_DEF, analysisRegistry)) {
             for (int i = randomInt(2); i < 3; i++) {
                 hash.add(page, new GroupingAggregatorFunction.AddInput() {
                     private void addBlock(int positionOffset, IntBlock groupIds) {

@@ -170,7 +176,7 @@ public void testCategorizeRawMultivalue() {
             page = new Page(builder.build());
         }
 
-        try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) {
+        try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, CATEGORIZE_DEF, analysisRegistry)) {
             for (int i = randomInt(2); i < 3; i++) {
                 hash.add(page, new GroupingAggregatorFunction.AddInput() {
                     private void addBlock(int positionOffset, IntBlock groupIds) {

@@ -259,8 +265,8 @@ public void testCategorizeIntermediate() {
 
         // Fill intermediatePages with the intermediate state from the raw hashes
         try (
-            BlockHash rawHash1 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry);
-            BlockHash rawHash2 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry);
+            BlockHash rawHash1 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, CATEGORIZE_DEF, analysisRegistry);
+            BlockHash rawHash2 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, CATEGORIZE_DEF, analysisRegistry);
         ) {
             rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() {
                 private void addBlock(int positionOffset, IntBlock groupIds) {

@@ -335,7 +341,7 @@ public void close() {
            page2.releaseBlocks();
        }
 
-        try (var intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.FINAL, null)) {
+        try (var intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.FINAL, CATEGORIZE_DEF, null)) {
            intermediateHash.add(intermediatePage1, new GroupingAggregatorFunction.AddInput() {
                private void addBlock(int positionOffset, IntBlock groupIds) {
                    List<Integer> values = IntStream.range(0, groupIds.getPositionCount())

@@ -590,7 +596,7 @@ public void testCategorize_withDriver() {
     }
 
     private BlockHash.GroupSpec makeGroupSpec() {
-        return new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true);
+        return new BlockHash.GroupSpec(0, ElementType.BYTES_REF, CATEGORIZE_DEF);
     }
 
     private void assertHashState(CategorizeBlockHash hash, boolean withNull, String... expectedKeys) {
