Skip to content

Commit 9441a74

Browse files
committed
Updated SemanticInferenceMetadataFieldsRecoveryTests to not use cosine similarity
1 parent 3e94ab2 commit 9441a74

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.elasticsearch.index.translog.Translog;
2929
import org.elasticsearch.inference.ChunkedInference;
3030
import org.elasticsearch.inference.Model;
31+
import org.elasticsearch.inference.SimilarityMeasure;
3132
import org.elasticsearch.inference.TaskType;
3233
import org.elasticsearch.plugins.MapperPlugin;
3334
import org.elasticsearch.xcontent.XContentBuilder;
@@ -54,7 +55,7 @@ public class SemanticInferenceMetadataFieldsRecoveryTests extends EngineTestCase
5455
private final boolean useIncludesExcludes;
5556

5657
public SemanticInferenceMetadataFieldsRecoveryTests(boolean useSynthetic, boolean useIncludesExcludes) {
57-
this.model1 = TestModel.createRandomInstance(TaskType.TEXT_EMBEDDING);
58+
this.model1 = TestModel.createRandomInstance(TaskType.TEXT_EMBEDDING, List.of(SimilarityMeasure.COSINE));
5859
this.model2 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
5960
this.useSynthetic = useSynthetic;
6061
this.useIncludesExcludes = useIncludesExcludes;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
import org.elasticsearch.xpack.inference.services.ServiceUtils;
2727

2828
import java.io.IOException;
29+
import java.util.ArrayList;
2930
import java.util.HashMap;
31+
import java.util.List;
3032
import java.util.Map;
3133

3234
import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength;
@@ -40,13 +42,35 @@ public static TestModel createRandomInstance() {
4042
}
4143

4244
public static TestModel createRandomInstance(TaskType taskType) {
45+
return createRandomInstance(taskType, null);
46+
}
47+
48+
public static TestModel createRandomInstance(TaskType taskType, List<SimilarityMeasure> excludedSimilarities) {
4349
var elementType = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(DenseVectorFieldMapper.ElementType.values()) : null;
4450
var dimensions = taskType == TaskType.TEXT_EMBEDDING
4551
? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 64)
4652
: null;
47-
var similarity = taskType == TaskType.TEXT_EMBEDDING
48-
? randomFrom(DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType))
49-
: null;
53+
54+
SimilarityMeasure similarity = null;
55+
if (taskType == TaskType.TEXT_EMBEDDING) {
56+
List<SimilarityMeasure> supportedSimilarities = new ArrayList<>(
57+
DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType)
58+
);
59+
if (excludedSimilarities != null) {
60+
supportedSimilarities.removeAll(excludedSimilarities);
61+
}
62+
63+
if (supportedSimilarities.isEmpty()) {
64+
throw new IllegalArgumentException(
65+
"No supported similarities for combination of element type ["
66+
+ elementType
67+
+ "] and excluded similarities "
68+
+ (excludedSimilarities == null ? List.of() : excludedSimilarities)
69+
);
70+
}
71+
72+
similarity = randomFrom(supportedSimilarities);
73+
}
5074

5175
return new TestModel(
5276
randomAlphaOfLength(4),

0 commit comments

Comments
 (0)