Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion x-pack/plugin/esql/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ esplugin {
name 'x-pack-esql'
description 'The plugin that powers ESQL for Elasticsearch'
classname 'org.elasticsearch.xpack.esql.plugin.EsqlPlugin'
extendedPlugins = ['x-pack-esql-core', 'lang-painless']
extendedPlugins = ['x-pack-esql-core', 'lang-painless', 'x-pack-ml']
}

base {
Expand All @@ -22,6 +22,7 @@ dependencies {
compileOnly project(path: xpackModule('core'))
compileOnly project(':modules:lang-painless:spi')
compileOnly project(xpackModule('esql-core'))
compileOnly project(xpackModule('ml'))
implementation project('compute')
implementation project('compute:ann')
implementation project(':libs:elasticsearch-dissect')
Expand Down
3 changes: 3 additions & 0 deletions x-pack/plugin/esql/compute/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@ base {
dependencies {
compileOnly project(':server')
compileOnly project('ann')
compileOnly project(xpackModule('ml'))
annotationProcessor project('gen')
implementation 'com.carrotsearch:hppc:0.8.1'

testImplementation project(':test:framework')
testImplementation(project(xpackModule('esql-core')))
testImplementation(project(xpackModule('core')))
testImplementation(project(xpackModule('ml')))
}

def projectDirectory = project.layout.projectDirectory
Expand Down
2 changes: 2 additions & 0 deletions x-pack/plugin/esql/compute/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

module org.elasticsearch.compute {

requires org.apache.lucene.analysis.common;
requires org.apache.lucene.core;
requires org.elasticsearch.base;
requires org.elasticsearch.server;
Expand All @@ -15,6 +16,7 @@
// required due to dependency on org.elasticsearch.common.util.concurrent.AbstractAsyncTask
requires org.apache.logging.log4j;
requires org.elasticsearch.logging;
requires org.elasticsearch.ml;
requires org.elasticsearch.tdigest;
requires org.elasticsearch.geo;
requires hppc;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
categorize
required_capability: categorize

FROM sample_data
| SORT message ASC
| STATS count=COUNT(), values=MV_SORT(VALUES(message)) BY category=CATEGORIZE(message)
| SORT category
;

count:long | values:keyword | category:integer
3 | [Connected to 10.1.0.1, Connected to 10.1.0.2, Connected to 10.1.0.3] | 0
3 | [Connection error] | 1
1 | [Disconnected] | 2
;

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,12 @@ public enum Cap {
/**
* Support explicit casting from string literal to DATE_PERIOD or TIME_DURATION.
*/
CAST_STRING_LITERAL_TO_TEMPORAL_AMOUNT;
CAST_STRING_LITERAL_TO_TEMPORAL_AMOUNT,

/**
* Support the text categorization function "CATEGORIZE".
*/
CATEGORIZE(true);

private final boolean snapshotOnly;
private final FeatureFlag featureFlag;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.aggregate.WeightedAvg;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least;
Expand Down Expand Up @@ -383,7 +384,10 @@ private FunctionDefinition[][] functions() {
}

private static FunctionDefinition[][] snapshotFunctions() {
return new FunctionDefinition[][] { new FunctionDefinition[] { def(Rate.class, Rate::withUnresolvedTimestamp, "rate") } };
return new FunctionDefinition[][] {
new FunctionDefinition[] {
def(Categorize.class, Categorize::new, "categorize"),
def(Rate.class, Rate::withUnresolvedTimestamp, "rate") } };
}

public EsqlFunctionRegistry snapshotRegistry() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.grouping;

import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.core.WhitespaceTokenizer;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.compute.ann.Evaluator;
import org.elasticsearch.compute.ann.Fixed;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.index.analysis.CharFilterFactory;
import org.elasticsearch.index.analysis.CustomAnalyzer;
import org.elasticsearch.index.analysis.TokenFilterFactory;
import org.elasticsearch.index.analysis.TokenizerFactory;
import org.elasticsearch.xpack.esql.capabilities.Validatable;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.function.Function;

import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString;

/**
 * Categorizes text messages.
 *
 * This implementation is incomplete and comes with the following caveats:
 * - it only works correctly on a single node.
 * - when running on multiple nodes, category IDs of the different nodes are
 *   aggregated, even though the same ID can correspond to a totally different
 *   category
 * - the output consists of category IDs, which should be replaced by category
 *   regexes or keys
 *
 * TODO(jan, nik): fix this
 */
public class Categorize extends GroupingFunction implements Validatable {
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
        Expression.class,
        "Categorize",
        Categorize::new
    );

    /** Expression producing the text to categorize; must resolve to a string type (see {@link #resolveType()}). */
    private final Expression field;

    @FunctionInfo(returnType = { "integer" }, description = "Categorizes text messages")
    public Categorize(
        Source source,
        @Param(name = "field", type = { "text", "keyword" }, description = "Expression to categorize") Expression field
    ) {
        super(source, List.of(field));
        this.field = field;
    }

    private Categorize(StreamInput in) throws IOException {
        this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class));
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        source().writeTo(out);
        out.writeNamedWriteable(field);
    }

    @Override
    public String getWriteableName() {
        return ENTRY.name;
    }

    @Override
    public boolean foldable() {
        // Foldable iff the input is foldable; the categorizer itself is deterministic
        // only within a single node (see class-level caveats).
        return field.foldable();
    }

    /**
     * Computes the category ID for a single value.
     *
     * @param v           UTF-8 encoded text to categorize
     * @param analyzer    analyzer used to tokenize the text
     * @param categorizer stateful categorizer; assigns and remembers category IDs per node
     * @return the numeric ID of the category assigned to {@code v}
     */
    @Evaluator
    static int process(
        BytesRef v,
        @Fixed(includeInToString = false, build = true) CategorizationAnalyzer analyzer,
        @Fixed(includeInToString = false, build = true) TokenListCategorizer.CloseableTokenListCategorizer categorizer
    ) {
        String s = v.utf8ToString();
        try (TokenStream ts = analyzer.tokenStream("text", s)) {
            return categorizer.computeCategory(ts, s.length(), 1).getId();
        } catch (IOException e) {
            // Tokenization happens fully in memory, so an IOException is unexpected here.
            // UncheckedIOException (a RuntimeException) preserves both the cause and its IO-specific type.
            throw new UncheckedIOException(e);
        }
    }

    @Override
    public ExpressionEvaluator.Factory toEvaluator(Function<Expression, ExpressionEvaluator.Factory> toEvaluator) {
        return new CategorizeEvaluator.Factory(
            source(),
            toEvaluator.apply(field),
            context -> new CategorizationAnalyzer(
                // TODO(jan): get the correct analyzer in here, see CategorizationAnalyzerConfig::buildStandardCategorizationAnalyzer
                new CustomAnalyzer(
                    TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
                    new CharFilterFactory[0],
                    new TokenFilterFactory[0]
                ),
                true
            ),
            context -> new TokenListCategorizer.CloseableTokenListCategorizer(
                new CategorizationBytesRefHash(new BytesRefHash(2048, context.bigArrays())),
                CategorizationPartOfSpeechDictionary.getInstance(),
                0.70f
            )
        );
    }

    @Override
    protected TypeResolution resolveType() {
        return isString(field(), sourceText(), DEFAULT);
    }

    @Override
    public DataType dataType() {
        return DataType.INTEGER;
    }

    @Override
    public Expression replaceChildren(List<Expression> newChildren) {
        return new Categorize(source(), newChildren.get(0));
    }

    @Override
    protected NodeInfo<? extends Expression> info() {
        return NodeInfo.create(this, Categorize::new, field);
    }

    public Expression field() {
        return field;
    }

    @Override
    public String toString() {
        return "Categorize{field=" + field + "}";
    }
}
Loading