Skip to content

Commit ebc5eca

Browse files
authored
[ML] add nlp config update serialization tests (#85867)
adds inference request serialization tests that include nlp config updates. relates: #85863
1 parent fede927 commit ebc5eca

File tree

7 files changed

+173
-85
lines changed

7 files changed

+173
-85
lines changed

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionRequestTests.java

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,22 @@
1414
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
1515
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdateTests;
1616
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdateTests;
17+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigUpdate;
18+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigUpdateTests;
1719
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
20+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigUpdate;
21+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigUpdateTests;
22+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfigUpdate;
23+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigUpdate;
24+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigUpdateTests;
1825
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdateTests;
1926
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdateTests;
27+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigUpdate;
28+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigUpdateTests;
29+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
30+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdateTests;
31+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdate;
32+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdateTests;
2033

2134
import java.util.ArrayList;
2235
import java.util.List;
@@ -44,6 +57,12 @@ private static InferenceConfigUpdate randomInferenceConfigUpdate() {
4457
RegressionConfigUpdateTests.randomRegressionConfigUpdate(),
4558
ClassificationConfigUpdateTests.randomClassificationConfigUpdate(),
4659
ResultsFieldUpdateTests.randomUpdate(),
60+
TextClassificationConfigUpdateTests.randomUpdate(),
61+
TextEmbeddingConfigUpdateTests.randomUpdate(),
62+
NerConfigUpdateTests.randomUpdate(),
63+
FillMaskConfigUpdateTests.randomUpdate(),
64+
ZeroShotClassificationConfigUpdateTests.randomUpdate(),
65+
PassThroughConfigUpdateTests.randomUpdate(),
4766
EmptyConfigUpdateTests.testInstance()
4867
);
4968
}
@@ -68,6 +87,27 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() {
6887

6988
@Override
7089
protected Request mutateInstanceForVersion(Request instance, Version version) {
71-
return instance;
90+
InferenceConfigUpdate adjustedUpdate;
91+
InferenceConfigUpdate currentUpdate = instance.getUpdate();
92+
if (currentUpdate instanceof NlpConfigUpdate nlpConfigUpdate) {
93+
if (nlpConfigUpdate instanceof TextClassificationConfigUpdate update) {
94+
adjustedUpdate = TextClassificationConfigUpdateTests.mutateForVersion(update, version);
95+
} else if (nlpConfigUpdate instanceof TextEmbeddingConfigUpdate update) {
96+
adjustedUpdate = TextEmbeddingConfigUpdateTests.mutateForVersion(update, version);
97+
} else if (nlpConfigUpdate instanceof NerConfigUpdate update) {
98+
adjustedUpdate = NerConfigUpdateTests.mutateForVersion(update, version);
99+
} else if (nlpConfigUpdate instanceof FillMaskConfigUpdate update) {
100+
adjustedUpdate = FillMaskConfigUpdateTests.mutateForVersion(update, version);
101+
} else if (nlpConfigUpdate instanceof ZeroShotClassificationConfigUpdate update) {
102+
adjustedUpdate = ZeroShotClassificationConfigUpdateTests.mutateForVersion(update, version);
103+
} else if (nlpConfigUpdate instanceof PassThroughConfigUpdate update) {
104+
adjustedUpdate = PassThroughConfigUpdateTests.mutateForVersion(update, version);
105+
} else {
106+
throw new IllegalArgumentException("Unknown update [" + currentUpdate.getName() + "]");
107+
}
108+
} else {
109+
adjustedUpdate = currentUpdate;
110+
}
111+
return new Request(instance.getModelId(), instance.getObjectsToInfer(), adjustedUpdate, instance.isPreviouslyLicensed());
72112
}
73113
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigUpdateTests.java

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,27 @@
2525

2626
public class FillMaskConfigUpdateTests extends AbstractNlpConfigUpdateTestCase<FillMaskConfigUpdate> {
2727

28+
public static FillMaskConfigUpdate randomUpdate() {
29+
FillMaskConfigUpdate.Builder builder = new FillMaskConfigUpdate.Builder();
30+
if (randomBoolean()) {
31+
builder.setNumTopClasses(randomIntBetween(1, 4));
32+
}
33+
if (randomBoolean()) {
34+
builder.setResultsField(randomAlphaOfLength(8));
35+
}
36+
if (randomBoolean()) {
37+
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
38+
}
39+
return builder.build();
40+
}
41+
42+
public static FillMaskConfigUpdate mutateForVersion(FillMaskConfigUpdate instance, Version version) {
43+
if (version.before(Version.V_8_1_0)) {
44+
return new FillMaskConfigUpdate(instance.getNumTopClasses(), instance.getResultsField(), null);
45+
}
46+
return instance;
47+
}
48+
2849
@Override
2950
Tuple<Map<String, Object>, FillMaskConfigUpdate> fromMapTestInstances(TokenizationUpdate expectedTokenization) {
3051
int topClasses = randomIntBetween(1, 10);
@@ -103,25 +124,12 @@ protected Writeable.Reader<FillMaskConfigUpdate> instanceReader() {
103124

104125
@Override
105126
protected FillMaskConfigUpdate createTestInstance() {
106-
FillMaskConfigUpdate.Builder builder = new FillMaskConfigUpdate.Builder();
107-
if (randomBoolean()) {
108-
builder.setNumTopClasses(randomIntBetween(1, 4));
109-
}
110-
if (randomBoolean()) {
111-
builder.setResultsField(randomAlphaOfLength(8));
112-
}
113-
if (randomBoolean()) {
114-
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
115-
}
116-
return builder.build();
127+
return randomUpdate();
117128
}
118129

119130
@Override
120131
protected FillMaskConfigUpdate mutateInstanceForVersion(FillMaskConfigUpdate instance, Version version) {
121-
if (version.before(Version.V_8_1_0)) {
122-
return new FillMaskConfigUpdate(instance.getNumTopClasses(), instance.getResultsField(), null);
123-
}
124-
return instance;
132+
return mutateForVersion(instance, version);
125133
}
126134

127135
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdateTests.java

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,24 @@
2626

2727
public class NerConfigUpdateTests extends AbstractNlpConfigUpdateTestCase<NerConfigUpdate> {
2828

29+
public static NerConfigUpdate randomUpdate() {
30+
NerConfigUpdate.Builder builder = new NerConfigUpdate.Builder();
31+
if (randomBoolean()) {
32+
builder.setResultsField(randomAlphaOfLength(8));
33+
}
34+
if (randomBoolean()) {
35+
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
36+
}
37+
return builder.build();
38+
}
39+
40+
public static NerConfigUpdate mutateForVersion(NerConfigUpdate instance, Version version) {
41+
if (version.before(Version.V_8_1_0)) {
42+
return new NerConfigUpdate(instance.getResultsField(), null);
43+
}
44+
return instance;
45+
}
46+
2947
@Override
3048
Tuple<Map<String, Object>, NerConfigUpdate> fromMapTestInstances(TokenizationUpdate expectedTokenization) {
3149
NerConfigUpdate expected = new NerConfigUpdate("ml-results", expectedTokenization);
@@ -86,22 +104,12 @@ protected Writeable.Reader<NerConfigUpdate> instanceReader() {
86104

87105
@Override
88106
protected NerConfigUpdate createTestInstance() {
89-
NerConfigUpdate.Builder builder = new NerConfigUpdate.Builder();
90-
if (randomBoolean()) {
91-
builder.setResultsField(randomAlphaOfLength(8));
92-
}
93-
if (randomBoolean()) {
94-
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
95-
}
96-
return builder.build();
107+
return randomUpdate();
97108
}
98109

99110
@Override
100111
protected NerConfigUpdate mutateInstanceForVersion(NerConfigUpdate instance, Version version) {
101-
if (version.before(Version.V_8_1_0)) {
102-
return new NerConfigUpdate(instance.getResultsField(), null);
103-
}
104-
return instance;
112+
return mutateForVersion(instance, version);
105113
}
106114

107115
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdateTests.java

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,24 @@
2626

2727
public class PassThroughConfigUpdateTests extends AbstractNlpConfigUpdateTestCase<PassThroughConfigUpdate> {
2828

29+
public static PassThroughConfigUpdate randomUpdate() {
30+
PassThroughConfigUpdate.Builder builder = new PassThroughConfigUpdate.Builder();
31+
if (randomBoolean()) {
32+
builder.setResultsField(randomAlphaOfLength(8));
33+
}
34+
if (randomBoolean()) {
35+
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
36+
}
37+
return builder.build();
38+
}
39+
40+
public static PassThroughConfigUpdate mutateForVersion(PassThroughConfigUpdate instance, Version version) {
41+
if (version.before(Version.V_8_1_0)) {
42+
return new PassThroughConfigUpdate(instance.getResultsField(), null);
43+
}
44+
return instance;
45+
}
46+
2947
@Override
3048
Tuple<Map<String, Object>, PassThroughConfigUpdate> fromMapTestInstances(TokenizationUpdate expectedTokenization) {
3149
PassThroughConfigUpdate expected = new PassThroughConfigUpdate("ml-results", expectedTokenization);
@@ -76,22 +94,12 @@ protected Writeable.Reader<PassThroughConfigUpdate> instanceReader() {
7694

7795
@Override
7896
protected PassThroughConfigUpdate createTestInstance() {
79-
PassThroughConfigUpdate.Builder builder = new PassThroughConfigUpdate.Builder();
80-
if (randomBoolean()) {
81-
builder.setResultsField(randomAlphaOfLength(8));
82-
}
83-
if (randomBoolean()) {
84-
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
85-
}
86-
return builder.build();
97+
return randomUpdate();
8798
}
8899

89100
@Override
90101
protected PassThroughConfigUpdate mutateInstanceForVersion(PassThroughConfigUpdate instance, Version version) {
91-
if (version.before(Version.V_8_1_0)) {
92-
return new PassThroughConfigUpdate(instance.getResultsField(), null);
93-
}
94-
return instance;
102+
return mutateForVersion(instance, version);
95103
}
96104

97105
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdateTests.java

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,35 @@
2828

2929
public class TextClassificationConfigUpdateTests extends AbstractNlpConfigUpdateTestCase<TextClassificationConfigUpdate> {
3030

31+
public static TextClassificationConfigUpdate randomUpdate() {
32+
TextClassificationConfigUpdate.Builder builder = new TextClassificationConfigUpdate.Builder();
33+
if (randomBoolean()) {
34+
builder.setNumTopClasses(randomIntBetween(1, 4));
35+
}
36+
if (randomBoolean()) {
37+
builder.setClassificationLabels(randomList(1, 3, () -> randomAlphaOfLength(4)));
38+
}
39+
if (randomBoolean()) {
40+
builder.setResultsField(randomAlphaOfLength(8));
41+
}
42+
if (randomBoolean()) {
43+
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
44+
}
45+
return builder.build();
46+
}
47+
48+
public static TextClassificationConfigUpdate mutateForVersion(TextClassificationConfigUpdate instance, Version version) {
49+
if (version.before(Version.V_8_1_0)) {
50+
return new TextClassificationConfigUpdate(
51+
instance.getClassificationLabels(),
52+
instance.getNumTopClasses(),
53+
instance.getResultsField(),
54+
null
55+
);
56+
}
57+
return instance;
58+
}
59+
3160
@Override
3261
Tuple<Map<String, Object>, TextClassificationConfigUpdate> fromMapTestInstances(TokenizationUpdate expectedTokenization) {
3362
int numClasses = randomIntBetween(1, 10);
@@ -159,33 +188,12 @@ protected Writeable.Reader<TextClassificationConfigUpdate> instanceReader() {
159188

160189
@Override
161190
protected TextClassificationConfigUpdate createTestInstance() {
162-
TextClassificationConfigUpdate.Builder builder = new TextClassificationConfigUpdate.Builder();
163-
if (randomBoolean()) {
164-
builder.setNumTopClasses(randomIntBetween(1, 4));
165-
}
166-
if (randomBoolean()) {
167-
builder.setClassificationLabels(randomList(1, 3, () -> randomAlphaOfLength(4)));
168-
}
169-
if (randomBoolean()) {
170-
builder.setResultsField(randomAlphaOfLength(8));
171-
}
172-
if (randomBoolean()) {
173-
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
174-
}
175-
return builder.build();
191+
return randomUpdate();
176192
}
177193

178194
@Override
179195
protected TextClassificationConfigUpdate mutateInstanceForVersion(TextClassificationConfigUpdate instance, Version version) {
180-
if (version.before(Version.V_8_1_0)) {
181-
return new TextClassificationConfigUpdate(
182-
instance.getClassificationLabels(),
183-
instance.getNumTopClasses(),
184-
instance.getResultsField(),
185-
null
186-
);
187-
}
188-
return instance;
196+
return mutateForVersion(instance, version);
189197
}
190198

191199
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdateTests.java

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,24 @@
2626

2727
public class TextEmbeddingConfigUpdateTests extends AbstractNlpConfigUpdateTestCase<TextEmbeddingConfigUpdate> {
2828

29+
public static TextEmbeddingConfigUpdate randomUpdate() {
30+
TextEmbeddingConfigUpdate.Builder builder = new TextEmbeddingConfigUpdate.Builder();
31+
if (randomBoolean()) {
32+
builder.setResultsField(randomAlphaOfLength(8));
33+
}
34+
if (randomBoolean()) {
35+
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
36+
}
37+
return builder.build();
38+
}
39+
40+
public static TextEmbeddingConfigUpdate mutateForVersion(TextEmbeddingConfigUpdate instance, Version version) {
41+
if (version.before(Version.V_8_1_0)) {
42+
return new TextEmbeddingConfigUpdate(instance.getResultsField(), null);
43+
}
44+
return instance;
45+
}
46+
2947
@Override
3048
Tuple<Map<String, Object>, TextEmbeddingConfigUpdate> fromMapTestInstances(TokenizationUpdate expectedTokenization) {
3149
TextEmbeddingConfigUpdate expected = new TextEmbeddingConfigUpdate("ml-results", expectedTokenization);
@@ -76,22 +94,12 @@ protected Writeable.Reader<TextEmbeddingConfigUpdate> instanceReader() {
7694

7795
@Override
7896
protected TextEmbeddingConfigUpdate createTestInstance() {
79-
TextEmbeddingConfigUpdate.Builder builder = new TextEmbeddingConfigUpdate.Builder();
80-
if (randomBoolean()) {
81-
builder.setResultsField(randomAlphaOfLength(8));
82-
}
83-
if (randomBoolean()) {
84-
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
85-
}
86-
return builder.build();
97+
return randomUpdate();
8798
}
8899

89100
@Override
90101
protected TextEmbeddingConfigUpdate mutateInstanceForVersion(TextEmbeddingConfigUpdate instance, Version version) {
91-
if (version.before(Version.V_8_1_0)) {
92-
return new TextEmbeddingConfigUpdate(instance.getResultsField(), null);
93-
}
94-
return instance;
102+
return mutateForVersion(instance, version);
95103
}
96104

97105
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,22 @@
2727

2828
public class ZeroShotClassificationConfigUpdateTests extends AbstractNlpConfigUpdateTestCase<ZeroShotClassificationConfigUpdate> {
2929

30+
public static ZeroShotClassificationConfigUpdate randomUpdate() {
31+
return new ZeroShotClassificationConfigUpdate(
32+
randomBoolean() ? null : randomList(1, 5, () -> randomAlphaOfLength(10)),
33+
randomBoolean() ? null : randomBoolean(),
34+
randomBoolean() ? null : randomAlphaOfLength(5),
35+
randomBoolean() ? null : new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null)
36+
);
37+
}
38+
39+
public static ZeroShotClassificationConfigUpdate mutateForVersion(ZeroShotClassificationConfigUpdate instance, Version version) {
40+
if (version.before(Version.V_8_1_0)) {
41+
return new ZeroShotClassificationConfigUpdate(instance.getLabels(), instance.getMultiLabel(), instance.getResultsField(), null);
42+
}
43+
return instance;
44+
}
45+
3046
@Override
3147
protected boolean supportsUnknownFields() {
3248
return false;
@@ -49,10 +65,7 @@ protected ZeroShotClassificationConfigUpdate createTestInstance() {
4965

5066
@Override
5167
protected ZeroShotClassificationConfigUpdate mutateInstanceForVersion(ZeroShotClassificationConfigUpdate instance, Version version) {
52-
if (version.before(Version.V_8_1_0)) {
53-
return new ZeroShotClassificationConfigUpdate(instance.getLabels(), instance.getMultiLabel(), instance.getResultsField(), null);
54-
}
55-
return instance;
68+
return mutateForVersion(instance, version);
5669
}
5770

5871
@Override
@@ -197,12 +210,7 @@ public void testIsNoop() {
197210
}
198211

199212
public static ZeroShotClassificationConfigUpdate createRandom() {
200-
return new ZeroShotClassificationConfigUpdate(
201-
randomBoolean() ? null : randomList(1, 5, () -> randomAlphaOfLength(10)),
202-
randomBoolean() ? null : randomBoolean(),
203-
randomBoolean() ? null : randomAlphaOfLength(5),
204-
randomBoolean() ? null : new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null)
205-
);
213+
return randomUpdate();
206214
}
207215

208216
@Override

0 commit comments

Comments
 (0)