Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.mapper.vectors;

import com.carrotsearch.randomizedtesting.RandomizedContext;
import com.carrotsearch.randomizedtesting.generators.RandomNumbers;

import org.elasticsearch.inference.SimilarityMeasure;

import java.util.List;
import java.util.Random;

public class DenseVectorFieldMapperTestUtils {
private DenseVectorFieldMapperTestUtils() {}

public static List<SimilarityMeasure> getSupportedSimilarities(DenseVectorFieldMapper.ElementType elementType) {
return switch (elementType) {
case FLOAT, BYTE -> List.of(SimilarityMeasure.values());
case BIT -> List.of(SimilarityMeasure.L2_NORM);
};
}

public static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) {
return switch (elementType) {
case FLOAT, BYTE -> dimensions;
case BIT -> {
assert dimensions % Byte.SIZE == 0;
yield dimensions / Byte.SIZE;
}
};
}

public static int randomCompatibleDimensions(DenseVectorFieldMapper.ElementType elementType, int max) {
if (max < 1) {
throw new IllegalArgumentException("max must be at least 1");
}

return switch (elementType) {
case FLOAT, BYTE -> RandomNumbers.randomIntBetween(random(), 1, max);
case BIT -> {
if (max < 8) {
throw new IllegalArgumentException("max must be at least 8 for bit vectors");
}

// Generate a random dimension count that is a multiple of 8
int maxEmbeddingLength = max / 8;
yield RandomNumbers.randomIntBetween(random(), 1, maxEmbeddingLength) * 8;
}
};
}

private static Random random() {
return RandomizedContext.current().getRandom();
}
}
1 change: 1 addition & 0 deletions x-pack/plugin/inference/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.elasticsearch.gradle.internal.info.BuildParams
apply plugin: 'elasticsearch.internal-es-plugin'
apply plugin: 'elasticsearch.internal-cluster-test'
apply plugin: 'elasticsearch.internal-yaml-rest-test'
apply plugin: 'elasticsearch.internal-test-artifact'

restResources {
restApi {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,7 @@ private static void addSemanticTextMapping(
mappingBuilder.endObject();
}

private static void addSemanticTextInferenceResults(XContentBuilder sourceBuilder, List<SemanticTextField> semanticTextInferenceResults)
public static void addSemanticTextInferenceResults(XContentBuilder sourceBuilder, List<SemanticTextField> semanticTextInferenceResults)
throws IOException {
for (var field : semanticTextInferenceResults) {
sourceBuilder.field(field.fieldName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
Expand All @@ -25,7 +26,10 @@
import org.elasticsearch.xpack.inference.services.ServiceUtils;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength;
Expand All @@ -39,9 +43,59 @@ public static TestModel createRandomInstance() {
}

public static TestModel createRandomInstance(TaskType taskType) {
var dimensions = taskType == TaskType.TEXT_EMBEDDING ? randomInt(64) : null;
var similarity = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(SimilarityMeasure.values()) : null;
var elementType = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(DenseVectorFieldMapper.ElementType.values()) : null;
return createRandomInstance(taskType, null, null);
}

public static TestModel createRandomInstance(
TaskType taskType,
List<DenseVectorFieldMapper.ElementType> excludedElementTypes,
List<SimilarityMeasure> excludedSimilarities
) {
return createRandomInstance(taskType, excludedElementTypes, excludedSimilarities, 128);
}

public static TestModel createRandomInstance(
TaskType taskType,
List<DenseVectorFieldMapper.ElementType> excludedElementTypes,
List<SimilarityMeasure> excludedSimilarities,
int maxDimensions
) {
List<DenseVectorFieldMapper.ElementType> supportedElementTypes = new ArrayList<>(
Arrays.asList(DenseVectorFieldMapper.ElementType.values())
);
if (excludedElementTypes != null) {
supportedElementTypes.removeAll(excludedElementTypes);
if (supportedElementTypes.isEmpty()) {
throw new IllegalArgumentException("No supported element types with excluded element types " + excludedElementTypes);
}
}

var elementType = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(supportedElementTypes) : null;
var dimensions = taskType == TaskType.TEXT_EMBEDDING
? DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, maxDimensions)
: null;

SimilarityMeasure similarity = null;
if (taskType == TaskType.TEXT_EMBEDDING) {
List<SimilarityMeasure> supportedSimilarities = new ArrayList<>(
DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType)
);
if (excludedSimilarities != null) {
supportedSimilarities.removeAll(excludedSimilarities);
}

if (supportedSimilarities.isEmpty()) {
throw new IllegalArgumentException(
"No supported similarities for combination of element type ["
+ elementType
+ "] and excluded similarities "
+ (excludedSimilarities == null ? List.of() : excludedSimilarities)
);
}

similarity = randomFrom(supportedSimilarities);
}

return new TestModel(
randomAlphaOfLength(4),
taskType,
Expand Down
2 changes: 2 additions & 0 deletions x-pack/qa/rolling-upgrade/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ apply plugin: 'elasticsearch.bwc-test'
apply plugin: 'elasticsearch.rest-resources'

dependencies {
testImplementation testArtifact(project(':server'))
testImplementation testArtifact(project(xpackModule('core')))
testImplementation project(':x-pack:qa')
testImplementation project(':modules:reindex')
testImplementation testArtifact(project(xpackModule('inference')))
}

restResources {
Expand Down
Loading