Skip to content

Commit 40bb62e

Browse files
authored
[ML] Fix ZeroShotClassificationConfig update mixing labels (#82848)
1 parent 2b3e41b commit 40bb62e

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdate.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ public InferenceConfig apply(InferenceConfig originalConfig) {
152152
}
153153

154154
boolean isNoop(ZeroShotClassificationConfig originalConfig) {
155-
return (labels == null || labels.equals(originalConfig.getClassificationLabels()))
155+
return (labels == null || labels.equals(originalConfig.getLabels()))
156156
&& (isMultiLabel == null || isMultiLabel.equals(originalConfig.isMultiLabel()))
157157
&& (resultsField == null || resultsField.equals(originalConfig.getResultsField()))
158158
&& super.isNoop();

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,30 @@ public void testApplyWithEmptyLabelsInConfigAndUpdate() {
178178

179179
public void testIsNoop() {
180180
assertTrue(new ZeroShotClassificationConfigUpdate.Builder().build().isNoop(ZeroShotClassificationConfigTests.createRandom()));
181+
182+
var originalConfig = new ZeroShotClassificationConfig(
183+
List.of("contradiction", "neutral", "entailment"),
184+
randomBoolean() ? null : VocabularyConfigTests.createRandom(),
185+
randomBoolean() ? null : BertTokenizationTests.createRandom(),
186+
randomAlphaOfLength(10),
187+
randomBoolean(),
188+
null,
189+
randomBoolean() ? null : randomAlphaOfLength(8)
190+
);
191+
192+
var update = new ZeroShotClassificationConfigUpdate.Builder().setLabels(List.of("glad", "sad", "mad")).build();
193+
assertFalse(update.isNoop(originalConfig));
194+
195+
originalConfig = new ZeroShotClassificationConfig(
196+
List.of("contradiction", "neutral", "entailment"),
197+
randomBoolean() ? null : VocabularyConfigTests.createRandom(),
198+
randomBoolean() ? null : BertTokenizationTests.createRandom(),
199+
randomAlphaOfLength(10),
200+
randomBoolean(),
201+
List.of("glad", "sad", "mad"),
202+
randomBoolean() ? null : randomAlphaOfLength(8)
203+
);
204+
assertTrue(update.isNoop(originalConfig));
181205
}
182206

183207
public static ZeroShotClassificationConfigUpdate createRandom() {

0 commit comments

Comments
 (0)