Skip to content

Commit 9983298

Browse files
authored
[ML] fix LangIdent model when multiple unicode scripts are present (#81876) (#81890)
LangIdent was recently updated to handle multiple unicode scripts (#80675). But this introduced some bugs fixed with this commit. 1. Sections with the same scripted were weighted by Java string length (utf-16) encoding. This is not accurate as certain languages (like Chinese and Korean) convey much more information with fewer utf-16 characters. FIX weight by utf-8 length. 2. The weighing of different language scores was done via the raw score from the neural network. This caused languages with a high score (but low compared to most likely language) from the network to be inaccurately weighted. FIX We are now instead weighing the probabilities of the sections of the text. 3. To split the input across the multiple scripts, we split on the "paired down" CDL3 script types. Java has superior support for unicode script blocks. FIX split by Java unicode script blocks not by the paired down CDL3 scripts
1 parent 74539ed commit 9983298

File tree

3 files changed

+57
-19
lines changed

3 files changed

+57
-19
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureValue;
2323
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor;
2424
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor;
25-
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptCode;
2625
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor;
2726

2827
import java.io.IOException;
28+
import java.nio.charset.StandardCharsets;
2929
import java.util.ArrayList;
3030
import java.util.Arrays;
3131
import java.util.Collections;
@@ -46,16 +46,16 @@
4646
public class CustomWordEmbedding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
4747

4848
public static class StringLengthAndEmbedding {
49-
final int stringLen;
49+
final int utf8StringLen;
5050
final double[] embedding;
5151

52-
public StringLengthAndEmbedding(int stringLen, double[] embedding) {
53-
this.stringLen = stringLen;
52+
public StringLengthAndEmbedding(int utf8StringLen, double[] embedding) {
53+
this.utf8StringLen = utf8StringLen;
5454
this.embedding = embedding;
5555
}
5656

57-
public int getStringLen() {
58-
return stringLen;
57+
public int getUtf8StringLen() {
58+
return utf8StringLen;
5959
}
6060

6161
public double[] getEmbedding() {
@@ -282,7 +282,7 @@ public void process(Map<String, Object> fields) {
282282
if (i >= codePoints.length) {
283283
break;
284284
}
285-
ScriptCode currentCode = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[i]));
285+
Character.UnicodeScript currentCode = Character.UnicodeScript.of(codePoints[i]);
286286
int j = i + 1;
287287
for (; j < codePoints.length; j++) {
288288
while (j < codePoints.length && Character.isLetter(codePoints[j]) == false) {
@@ -291,11 +291,11 @@ public void process(Map<String, Object> fields) {
291291
if (j >= codePoints.length) {
292292
break;
293293
}
294-
ScriptCode j1 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j]));
295-
if (j1 != currentCode && j1 != ScriptCode.Inherited) {
294+
Character.UnicodeScript j1 = Character.UnicodeScript.of(codePoints[j]);
295+
if (j1 != currentCode && j1 != Character.UnicodeScript.INHERITED) {
296296
if (j < codePoints.length - 1) {
297-
ScriptCode j2 = ScriptCode.unicodeScriptToULScript(Character.UnicodeScript.of(codePoints[j + 1]));
298-
if (j2 != ScriptCode.Common && j2 != currentCode) {
297+
Character.UnicodeScript j2 = Character.UnicodeScript.of(codePoints[j + 1]);
298+
if (j2 != Character.UnicodeScript.COMMON && j2 != currentCode) {
299299
break;
300300
}
301301
}
@@ -314,7 +314,11 @@ public void process(Map<String, Object> fields) {
314314
embeddings.add(
315315
new StringLengthAndEmbedding(
316316
// Don't count white spaces as bytes for the prediction
317-
str.trim().length(),
317+
// We ues utf-8 length here as
318+
// * The original C++ implementation does this when measuring string length
319+
// * Languages with complex characters (like zh) convey more information per a single utf-16 character and
320+
// using utf-8 length captures that.
321+
str.trim().getBytes(StandardCharsets.UTF_8).length,
318322
concatEmbeddings(
319323
FEATURE_EXTRACTORS.stream()
320324
.map((featureExtractor) -> featureExtractor.extractFeatures(builder.toString()))

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,23 +230,22 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config
230230
);
231231
}
232232
List<?> embeddedVector = (List<?>) vector;
233-
double[] scores = new double[LANGUAGE_NAMES.size()];
233+
double[] probabilities = new double[LANGUAGE_NAMES.size()];
234234
int totalLen = 0;
235235
for (Object vec : embeddedVector) {
236236
if (vec instanceof CustomWordEmbedding.StringLengthAndEmbedding == false) {
237237
continue;
238238
}
239239
CustomWordEmbedding.StringLengthAndEmbedding stringLengthAndEmbedding = (CustomWordEmbedding.StringLengthAndEmbedding) vec;
240-
int square = stringLengthAndEmbedding.getStringLen() * stringLengthAndEmbedding.getStringLen();
240+
int square = stringLengthAndEmbedding.getUtf8StringLen() * stringLengthAndEmbedding.getUtf8StringLen();
241241
totalLen += square;
242242
double[] h0 = hiddenLayer.productPlusBias(false, stringLengthAndEmbedding.getEmbedding());
243243
double[] score = softmaxLayer.productPlusBias(true, h0);
244-
sumDoubleArrays(scores, score, Math.max(square, 1));
244+
sumDoubleArrays(probabilities, softMax(score), Math.max(square, 1));
245245
}
246246
if (totalLen != 0) {
247-
divMut(scores, totalLen);
247+
divMut(probabilities, totalLen);
248248
}
249-
double[] probabilities = softMax(scores);
250249
ClassificationConfig classificationConfig = (ClassificationConfig) config;
251250
Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
252251
probabilities,

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929

3030
import static org.hamcrest.CoreMatchers.equalTo;
3131
import static org.hamcrest.Matchers.closeTo;
32-
import static org.hamcrest.Matchers.greaterThan;
3332
import static org.mockito.Mockito.mock;
3433

3534
public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
@@ -103,6 +102,12 @@ public void testMixedLangInference() throws Exception {
103102
singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(inferenceObj("이Q현"), classificationConfig);
104103
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
105104

105+
singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
106+
inferenceObj("매트 스미스는 BBC äôs Doctor Who를 그만둔다."),
107+
classificationConfig
108+
);
109+
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
110+
106111
singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
107112
inferenceObj(
108113
"@#$%^&*(행 레이블 Dashboard ISSUE Qual. Plan Qual. !@#$%^&*() Report Qual."
@@ -112,6 +117,34 @@ public void testMixedLangInference() throws Exception {
112117
);
113118
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
114119

120+
singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
121+
inferenceObj(
122+
"김걸도혁(金乞都革) 김공소(金公疎) 김교합(金咬哈) 김다롱합(金多弄哈) 김마상개(金麻尙介) 김우리개(金于里介) 김상미(金尙美) 김아도을치(金阿都乙赤) "
123+
+ "김아라(金阿喇) 김아랑합(金阿郞哈) 김아을가(金阿乙加) 김역류(金易留) 김우두(金于豆) 김우허내(金右虛乃) 김유리가(金留里加) 김윤적(金允績) "
124+
+ "김이랑합(金伊郞哈) 김인을개(金引乙介) 김입성(金入成) 김주창개(金主昌介) 김지하리(金之下里) 김차독(金箚禿) 김지칭가(金只稱哥) 김자라노(金者羅老)."
125+
),
126+
classificationConfig
127+
);
128+
// Half the string is ko the other half is zh
129+
assertThat(singleValueInferenceResults.valueAsString(), equalTo("ko"));
130+
assertThat(singleValueInferenceResults.getPredictionScore(), closeTo(0.5, 0.1));
131+
assertThat(singleValueInferenceResults.getTopClasses().get(1).getClassification(), equalTo("zh"));
132+
assertThat(singleValueInferenceResults.getTopClasses().get(1).getScore(), closeTo(0.5, 0.1));
133+
134+
singleValueInferenceResults = (ClassificationInferenceResults) inferenceDefinition.infer(
135+
inferenceObj(
136+
"[ Republic of Korea ],\n"
137+
+ "วันนี้ - ตัวอย่างนี้เป็นภาษาไทย\n"
138+
+ "วันนี้ - ตัวอย่างนี้เป็นภาษาไทย\n"
139+
+ " !대한민국(, 영어: Republic of Korea, KOR)은 동아시아의 한반도 남부에 자리한 민주공화국이다. 서쪽으로 중화인민공화국과 황해를 사이에 두고"
140+
),
141+
classificationConfig
142+
);
143+
// Majority of the text is obviously Thai, but a close second is Korean
144+
assertThat(singleValueInferenceResults.valueAsString(), equalTo("th"));
145+
assertThat(singleValueInferenceResults.getPredictionScore(), closeTo(0.6, 0.1));
146+
assertThat(singleValueInferenceResults.getTopClasses().get(1).getClassification(), equalTo("ko"));
147+
assertThat(singleValueInferenceResults.getTopClasses().get(1).getScore(), closeTo(0.4, 0.1));
115148
}
116149

117150
public void testLangInference() throws Exception {
@@ -131,7 +164,9 @@ public void testLangInference() throws Exception {
131164
);
132165

133166
assertThat(singleValueInferenceResults.valueAsString(), equalTo(cld3Actual));
134-
Matcher<Double> matcher = entry.getLanguage().equals("hr") ? greaterThan(cld3Probability) : closeTo(cld3Probability, .00001);
167+
// The stored language example is a mixture of `ja` and other languages, it should not be predicted with 1.0 accuracy as the
168+
// cld3 probability indicates.
169+
Matcher<Double> matcher = entry.getLanguage().equals("ja") ? closeTo(cld3Probability, 0.11) : closeTo(cld3Probability, .01);
135170
assertThat(
136171
"mismatch probability for language " + cld3Actual,
137172
singleValueInferenceResults.getTopClasses().get(0).getProbability(),

0 commit comments

Comments
 (0)