Skip to content

Commit fede927

Browse files
authored
[ML] A text categorization aggregation that works like ML categorization (#80867)
This PR adds a text categorization aggregation that uses the same approaches as the categorization feature of ML anomaly detection jobs.
1 parent ad985fe commit fede927

25 files changed

+81060
-1
lines changed

docs/changelog/80867.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 80867
2+
summary: A text categorization aggregation that works like ML categorization
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.ml.integration;
9+
10+
import org.elasticsearch.action.admin.indices.create.CreateIndexRequest;
11+
import org.elasticsearch.action.admin.indices.stats.IndicesStatsResponse;
12+
import org.elasticsearch.action.admin.indices.stats.ShardStats;
13+
import org.elasticsearch.action.bulk.BulkRequestBuilder;
14+
import org.elasticsearch.action.index.IndexRequestBuilder;
15+
import org.elasticsearch.action.search.SearchResponse;
16+
import org.elasticsearch.cluster.metadata.IndexMetadata;
17+
import org.elasticsearch.cluster.routing.ShardRouting;
18+
import org.elasticsearch.common.settings.Settings;
19+
import org.elasticsearch.xpack.ml.aggs.categorization2.CategorizeTextAggregationBuilder;
20+
import org.elasticsearch.xpack.ml.aggs.categorization2.InternalCategorizationAggregation;
21+
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
22+
23+
import java.util.Arrays;
24+
import java.util.HashSet;
25+
import java.util.List;
26+
import java.util.Map;
27+
import java.util.Set;
28+
import java.util.stream.Collectors;
29+
30+
import static org.hamcrest.Matchers.empty;
31+
import static org.hamcrest.Matchers.hasSize;
32+
import static org.hamcrest.Matchers.is;
33+
import static org.hamcrest.Matchers.notNullValue;
34+
35+
public class CategorizeTextDistributedIT extends BaseMlIntegTestCase {
36+
37+
/**
38+
* When categorizing text in a multi-node cluster the categorize_text2 aggregation has
39+
* a harder job than in a single node cluster. The categories must be serialized between
40+
* nodes and then merged appropriately on the receiving node. This test ensures that
41+
* this serialization and subsequent merging works in the same way that merging would work
42+
* on a single node.
43+
*/
44+
public void testDistributedCategorizeText() {
45+
internalCluster().ensureAtLeastNumDataNodes(3);
46+
ensureStableCluster();
47+
48+
// System indices may affect the distribution of shards of this index,
49+
// but it has so many that it should have shards on all the nodes
50+
String indexName = "data";
51+
CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName).settings(
52+
Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, "9").put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, "0")
53+
);
54+
client().admin().indices().create(createIndexRequest).actionGet();
55+
56+
// Spread 10000 documents in 4 categories across the shards
57+
for (int i = 0; i < 10; ++i) {
58+
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
59+
for (int j = 0; j < 250; ++j) {
60+
IndexRequestBuilder indexRequestBuilder = client().prepareIndex(indexName)
61+
.setSource(Map.of("message", "Aug 29, 2019 2:02:51 PM org.apache.coyote.http11.Http11BaseProtocol destroy"));
62+
bulkRequestBuilder.add(indexRequestBuilder);
63+
indexRequestBuilder = client().prepareIndex(indexName)
64+
.setSource(Map.of("message", "Aug 29, 2019 2:02:51 PM org.apache.coyote.http11.Http11BaseProtocol init"));
65+
bulkRequestBuilder.add(indexRequestBuilder);
66+
indexRequestBuilder = client().prepareIndex(indexName)
67+
.setSource(Map.of("message", "Aug 29, 2019 2:02:51 PM org.apache.coyote.http11.Http11BaseProtocol start"));
68+
bulkRequestBuilder.add(indexRequestBuilder);
69+
indexRequestBuilder = client().prepareIndex(indexName)
70+
.setSource(Map.of("message", "Aug 29, 2019 2:02:51 PM org.apache.coyote.http11.Http11BaseProtocol stop"));
71+
bulkRequestBuilder.add(indexRequestBuilder);
72+
}
73+
bulkRequestBuilder.execute().actionGet();
74+
}
75+
client().admin().indices().prepareRefresh(indexName).execute().actionGet();
76+
77+
// Confirm the theory that all 3 nodes will have a shard on
78+
IndicesStatsResponse indicesStatsResponse = client().admin().indices().prepareStats(indexName).execute().actionGet();
79+
Set<String> nodesWithShards = Arrays.stream(indicesStatsResponse.getShards())
80+
.map(ShardStats::getShardRouting)
81+
.map(ShardRouting::currentNodeId)
82+
.collect(Collectors.toSet());
83+
assertThat(nodesWithShards, hasSize(internalCluster().size()));
84+
85+
SearchResponse searchResponse = client().prepareSearch(indexName)
86+
.addAggregation(new CategorizeTextAggregationBuilder("categories", "message"))
87+
.setSize(0)
88+
.execute()
89+
.actionGet();
90+
91+
InternalCategorizationAggregation aggregation = searchResponse.getAggregations().get("categories");
92+
assertThat(aggregation, notNullValue());
93+
94+
// We should have created 4 categories, one for each of the distinct messages we indexed, all with counts of 2500 (= 10000/4)
95+
List<InternalCategorizationAggregation.Bucket> buckets = aggregation.getBuckets();
96+
assertThat(buckets, notNullValue());
97+
assertThat(buckets, hasSize(4));
98+
Set<String> expectedLastTokens = new HashSet<>(List.of("destroy", "init", "start", "stop"));
99+
for (InternalCategorizationAggregation.Bucket bucket : buckets) {
100+
assertThat(bucket.getDocCount(), is(2500L));
101+
String[] tokens = bucket.getKeyAsString().split(" ");
102+
String lastToken = tokens[tokens.length - 1];
103+
assertThat(lastToken + " not found in " + expectedLastTokens, expectedLastTokens.remove(lastToken), is(true));
104+
}
105+
assertThat("Some expected last tokens not found " + expectedLastTokens, expectedLastTokens, empty());
106+
}
107+
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1417,7 +1417,16 @@ public List<AggregationSpec> getAggregations() {
14171417
CategorizeTextAggregationBuilder::new,
14181418
CategorizeTextAggregationBuilder.PARSER
14191419
).addResultReader(InternalCategorizationAggregation::new)
1420-
.setAggregatorRegistrar(s -> s.registerUsage(CategorizeTextAggregationBuilder.NAME))
1420+
.setAggregatorRegistrar(s -> s.registerUsage(CategorizeTextAggregationBuilder.NAME)),
1421+
// TODO: in the long term only keep one or other of these categorization aggregations
1422+
new AggregationSpec(
1423+
org.elasticsearch.xpack.ml.aggs.categorization2.CategorizeTextAggregationBuilder.NAME,
1424+
org.elasticsearch.xpack.ml.aggs.categorization2.CategorizeTextAggregationBuilder::new,
1425+
org.elasticsearch.xpack.ml.aggs.categorization2.CategorizeTextAggregationBuilder.PARSER
1426+
).addResultReader(org.elasticsearch.xpack.ml.aggs.categorization2.InternalCategorizationAggregation::new)
1427+
.setAggregatorRegistrar(
1428+
s -> s.registerUsage(org.elasticsearch.xpack.ml.aggs.categorization2.CategorizeTextAggregationBuilder.NAME)
1429+
)
14211430
);
14221431
}
14231432

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.ml.aggs.categorization2;
9+
10+
import org.apache.lucene.util.BytesRef;
11+
import org.elasticsearch.common.logging.LoggerMessageFormat;
12+
import org.elasticsearch.common.util.BytesRefHash;
13+
import org.elasticsearch.core.Releasable;
14+
import org.elasticsearch.search.aggregations.AggregationExecutionException;
15+
16+
class CategorizationBytesRefHash implements Releasable {
17+
18+
private final BytesRefHash bytesRefHash;
19+
20+
CategorizationBytesRefHash(BytesRefHash bytesRefHash) {
21+
this.bytesRefHash = bytesRefHash;
22+
}
23+
24+
int[] getIds(BytesRef[] tokens) {
25+
int[] ids = new int[tokens.length];
26+
for (int i = 0; i < tokens.length; i++) {
27+
ids[i] = put(tokens[i]);
28+
}
29+
return ids;
30+
}
31+
32+
BytesRef[] getDeeps(int[] ids) {
33+
BytesRef[] tokens = new BytesRef[ids.length];
34+
for (int i = 0; i < tokens.length; i++) {
35+
tokens[i] = getDeep(ids[i]);
36+
}
37+
return tokens;
38+
}
39+
40+
BytesRef getDeep(long id) {
41+
BytesRef shallow = bytesRefHash.get(id, new BytesRef());
42+
return BytesRef.deepCopyOf(shallow);
43+
}
44+
45+
int put(BytesRef bytesRef) {
46+
long hash = bytesRefHash.add(bytesRef);
47+
if (hash < 0) {
48+
// BytesRefHash returns -1 - hash if the entry already existed, but we just want to return the hash
49+
return (int) (-1L - hash);
50+
}
51+
if (hash > Integer.MAX_VALUE) {
52+
throw new AggregationExecutionException(
53+
LoggerMessageFormat.format(
54+
"more than [{}] unique terms encountered. "
55+
+ "Consider restricting the documents queried or adding [{}] in the {} configuration",
56+
Integer.MAX_VALUE,
57+
CategorizeTextAggregationBuilder.CATEGORIZATION_FILTERS.getPreferredName(),
58+
CategorizeTextAggregationBuilder.NAME
59+
)
60+
);
61+
}
62+
return (int) hash;
63+
}
64+
65+
@Override
66+
public void close() {
67+
bytesRefHash.close();
68+
}
69+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.ml.aggs.categorization2;
9+
10+
import java.io.BufferedReader;
11+
import java.io.IOException;
12+
import java.io.InputStream;
13+
import java.io.InputStreamReader;
14+
import java.nio.charset.StandardCharsets;
15+
import java.util.HashMap;
16+
import java.util.Locale;
17+
import java.util.Map;
18+
import java.util.function.Function;
19+
import java.util.stream.Collectors;
20+
import java.util.stream.Stream;
21+
22+
/**
 * Port of the C++ class <a href="https://github.com/elastic/ml-cpp/blob/main/include/core/CWordDictionary.h">
 * <code>CWordDictionary</code></a>.
 */
public class CategorizationPartOfSpeechDictionary {

    static final String DICTIONARY_FILE_PATH = "/org/elasticsearch/xpack/ml/aggs/categorization2/ml-en.dict";

    static final String PART_OF_SPEECH_SEPARATOR = "@";

    /**
     * Parts of speech a dictionary word may be tagged with. Each constant carries the
     * single-character code used in the dictionary file.
     */
    public enum PartOfSpeech {
        NOT_IN_DICTIONARY('\0'),
        UNKNOWN('?'),
        NOUN('N'),
        PLURAL('p'),
        VERB('V'),
        ADJECTIVE('A'),
        ADVERB('v'),
        CONJUNCTION('C'),
        PREPOSITION('P'),
        INTERJECTION('!'),
        PRONOUN('r'),
        DEFINITE_ARTICLE('D'),
        INDEFINITE_ARTICLE('I');

        private final char code;

        PartOfSpeech(char code) {
            this.code = code;
        }

        char getCode() {
            return code;
        }

        private static final Map<Character, PartOfSpeech> CODE_MAPPING =
            // 'h', 'o', 't', and 'i' are codes for specialist types of noun and verb that we don't distinguish
            Stream.concat(
                Map.of('h', NOUN, 'o', NOUN, 't', VERB, 'i', VERB).entrySet().stream(),
                Stream.of(PartOfSpeech.values()).collect(Collectors.toMap(PartOfSpeech::getCode, Function.identity())).entrySet().stream()
            )
                .collect(
                    Collectors.toUnmodifiableMap(Map.Entry<Character, PartOfSpeech>::getKey, Map.Entry<Character, PartOfSpeech>::getValue)
                );

        /**
         * @param partOfSpeechCode single-character code read from the dictionary file
         * @return the corresponding part of speech
         * @throws IllegalArgumentException if the code is not recognized
         */
        static PartOfSpeech fromCode(char partOfSpeechCode) {
            PartOfSpeech pos = CODE_MAPPING.get(partOfSpeechCode);
            if (pos == null) {
                throw new IllegalArgumentException("Unknown part-of-speech code [" + partOfSpeechCode + "]");
            }
            return pos;
        }
    }

    /**
     * Lazy loaded singleton instance to avoid loading the dictionary repeatedly.
     * Declared volatile so the unsynchronized fast-path read in {@link #getInstance()}
     * cannot observe a partially constructed instance (safe publication under the
     * Java Memory Model; without volatile this double-checked locking is broken).
     */
    private static volatile CategorizationPartOfSpeechDictionary instance;
    private static final Object INIT_LOCK = new Object();

    /**
     * Keys are lower case.
     */
    private final Map<String, PartOfSpeech> partOfSpeechDictionary = new HashMap<>();
    // Length of the longest dictionary word; lets getPartOfSpeech() reject oversized words without a map lookup
    private final int maxDictionaryWordLength;

    /**
     * Builds the dictionary from a stream of lines of the form {@code word@<code>}.
     * Blank lines are skipped; words are stored lower-cased. The caller retains
     * ownership of the stream and is responsible for closing it.
     *
     * @param is dictionary contents, UTF-8 encoded
     * @throws IOException if reading the stream fails
     * @throws IllegalArgumentException if a line is malformed or uses an unknown part-of-speech code
     */
    CategorizationPartOfSpeechDictionary(InputStream is) throws IOException {

        int maxLength = 0;
        BufferedReader reader = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8));
        String line;
        while ((line = reader.readLine()) != null) {
            line = line.trim();
            if (line.isEmpty()) {
                continue;
            }
            String[] split = line.split(PART_OF_SPEECH_SEPARATOR);
            if (split.length != 2) {
                throw new IllegalArgumentException(
                    "Unexpected format in line [" + line + "]: expected one [" + PART_OF_SPEECH_SEPARATOR + "] separator"
                );
            }
            if (split[0].isEmpty()) {
                throw new IllegalArgumentException(
                    "Unexpected format in line [" + line + "]: nothing preceding [" + PART_OF_SPEECH_SEPARATOR + "] separator"
                );
            }
            if (split[1].isEmpty()) {
                throw new IllegalArgumentException(
                    "Unexpected format in line [" + line + "]: nothing following [" + PART_OF_SPEECH_SEPARATOR + "] separator"
                );
            }
            String lowerCaseWord = split[0].toLowerCase(Locale.ROOT);
            partOfSpeechDictionary.put(lowerCaseWord, PartOfSpeech.fromCode(split[1].charAt(0)));
            maxLength = Math.max(maxLength, lowerCaseWord.length());
        }
        maxDictionaryWordLength = maxLength;
    }

    // TODO: now we have this in Java, perform this operation in Java for anomaly detection categorization instead of in C++.
    // (It could maybe be incorporated into the categorization analyzer and then shared between aggregation and anomaly detection.)
    /**
     * Find the part of speech (noun, verb, adjective, etc.) for a supplied word.
     * The lookup is case-insensitive.
     * @return Which part of speech does the supplied word represent? {@link PartOfSpeech#NOT_IN_DICTIONARY} is returned
     *         for words that aren't in the dictionary at all.
     */
    public PartOfSpeech getPartOfSpeech(String word) {
        if (word.length() > maxDictionaryWordLength) {
            // Cannot possibly match — skip the allocation and lookup below
            return PartOfSpeech.NOT_IN_DICTIONARY;
        }
        // This is quite slow as it creates a new string for every lookup. However, experiments show
        // that trying to do case-insensitive comparisons instead of creating a lower case string is
        // even slower.
        return partOfSpeechDictionary.getOrDefault(word.toLowerCase(Locale.ROOT), PartOfSpeech.NOT_IN_DICTIONARY);
    }

    /**
     * @return Is the supplied word in the dictionary?
     */
    public boolean isInDictionary(String word) {
        return getPartOfSpeech(word) != PartOfSpeech.NOT_IN_DICTIONARY;
    }

    /**
     * @return the shared dictionary instance, loading it from the bundled resource on first use
     * @throws IOException if the bundled dictionary resource is missing or cannot be read
     */
    public static CategorizationPartOfSpeechDictionary getInstance() throws IOException {
        if (instance != null) {
            return instance;
        }
        synchronized (INIT_LOCK) {
            if (instance == null) {
                try (InputStream is = CategorizationPartOfSpeechDictionary.class.getResourceAsStream(DICTIONARY_FILE_PATH)) {
                    if (is == null) {
                        // getResourceAsStream returns null when the resource is absent; fail with a
                        // clear message rather than an opaque NPE from the reader in the constructor
                        throw new IOException("Part-of-speech dictionary resource [" + DICTIONARY_FILE_PATH + "] not found");
                    }
                    instance = new CategorizationPartOfSpeechDictionary(is);
                }
            }
            return instance;
        }
    }
}

0 commit comments

Comments
 (0)