Skip to content

Commit 4371b69

Browse files
committed
Better parameterized test
1 parent 7d2625c commit 4371b69

File tree

1 file changed

+19
-21
lines changed

1 file changed

+19
-21
lines changed

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929
import java.util.Locale;
3030
import java.util.Map;
3131
import java.util.Set;
32-
import java.util.function.Function;
33-
import java.util.function.Supplier;
3432

3533
import static org.elasticsearch.index.IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING;
3634
import static org.elasticsearch.index.mapper.SourceFieldMapper.Mode.SYNTHETIC;
@@ -54,36 +52,33 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {
5452
private final ElementType elementType;
5553
private final DenseVectorFieldMapper.VectorSimilarity similarity;
5654
private final boolean synthetic;
57-
private final String indexType;
5855
private final boolean index;
5956

6057
@ParametersFactory
6158
public static Iterable<Object[]> parameters() throws Exception {
6259
List<Object[]> params = new ArrayList<>();
63-
// Indexed field types
64-
Supplier<ElementType> elementTypeProvider = () -> ElementType.FLOAT;
65-
Function<ElementType, String> indexTypeProvider = e -> randomFrom(ALL_DENSE_VECTOR_INDEX_TYPES);
66-
Supplier<DenseVectorFieldMapper.VectorSimilarity> vectorSimilarityProvider = () -> randomFrom(
67-
DenseVectorFieldMapper.VectorSimilarity.values()
68-
);
69-
params.add(new Object[] { elementTypeProvider, indexTypeProvider, vectorSimilarityProvider, true, false });
60+
61+
// Test all similarities
62+
for (DenseVectorFieldMapper.VectorSimilarity similarity : DenseVectorFieldMapper.VectorSimilarity.values()) {
63+
params.add(new Object[] { ElementType.FLOAT, similarity, true, false });
64+
}
65+
7066
// No indexing
71-
params.add(new Object[] { elementTypeProvider, null, null, false, false });
67+
params.add(new Object[] { ElementType.FLOAT, null, false, false });
7268
// No indexing, synthetic source
73-
params.add(new Object[] { elementTypeProvider, null, null, false, true });
69+
params.add(new Object[] { ElementType.FLOAT, null, false, true });
70+
7471
return params;
7572
}
7673

7774
public DenseVectorFieldTypeIT(
78-
@Name("elementType") Supplier<ElementType> elementTypeProvider,
79-
@Name("indexType") Function<ElementType, String> indexTypeProvider,
80-
@Name("similarity") Supplier<DenseVectorFieldMapper.VectorSimilarity> similarityProvider,
75+
@Name("elementType") ElementType elementType,
76+
@Name("similarity") DenseVectorFieldMapper.VectorSimilarity similarity,
8177
@Name("index") boolean index,
8278
@Name("synthetic") boolean synthetic
8379
) {
84-
this.elementType = elementTypeProvider.get();
85-
this.indexType = indexTypeProvider == null ? null : indexTypeProvider.apply(this.elementType);
86-
this.similarity = similarityProvider == null ? null : similarityProvider.get();
80+
this.elementType = elementType;
81+
this.similarity = similarity;
8782
this.index = index;
8883
this.synthetic = synthetic;
8984
}
@@ -207,7 +202,11 @@ public void setup() throws IOException {
207202
indexedVectors.put(i, null);
208203
} else {
209204
for (int j = 0; j < numDims; j++) {
210-
vector.add(randomFloatBetween(0F, 1F, true));
205+
switch (elementType) {
206+
case FLOAT -> vector.add(randomFloatBetween(0F, 1F, true));
207+
case BYTE -> vector.add((byte) (randomFloatBetween(0F, 1F, true) * 127.0f));
208+
default -> throw new IllegalArgumentException("Unexpected element type: " + elementType);
209+
}
211210
}
212211
if ((similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT) || rarely()) {
213212
// Normalize the vector
@@ -236,8 +235,7 @@ private void createIndexWithDenseVector(String indexName) throws IOException {
236235
.field("index", index);
237236
if (index) {
238237
mapping.field("similarity", similarity.name().toLowerCase(Locale.ROOT));
239-
}
240-
if (indexType != null) {
238+
String indexType = randomFrom(ALL_DENSE_VECTOR_INDEX_TYPES);
241239
mapping.startObject("index_options").field("type", indexType).endObject();
242240
}
243241
mapping.endObject().endObject().endObject();

0 commit comments

Comments
 (0)