Skip to content

Commit 69fadc7

Browse files
Mikep86georgewallace
authored andcommitted
Add bit vector support to semantic text (elastic#123187)
1 parent 2bf626b commit 69fadc7

File tree

12 files changed

+270
-101
lines changed

12 files changed

+270
-101
lines changed

docs/changelog/123187.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 123187
2+
summary: Add bit vector support to semantic text
3+
area: Vector Search
4+
type: enhancement
5+
issues: []
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.mapper.vectors;
11+
12+
import com.carrotsearch.randomizedtesting.RandomizedContext;
13+
import com.carrotsearch.randomizedtesting.generators.RandomNumbers;
14+
15+
import org.elasticsearch.inference.SimilarityMeasure;
16+
17+
import java.util.List;
18+
import java.util.Random;
19+
20+
public class DenseVectorFieldMapperTestUtils {
21+
private DenseVectorFieldMapperTestUtils() {}
22+
23+
public static List<SimilarityMeasure> getSupportedSimilarities(DenseVectorFieldMapper.ElementType elementType) {
24+
return switch (elementType) {
25+
case FLOAT, BYTE -> List.of(SimilarityMeasure.values());
26+
case BIT -> List.of(SimilarityMeasure.L2_NORM);
27+
};
28+
}
29+
30+
public static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) {
31+
return switch (elementType) {
32+
case FLOAT, BYTE -> dimensions;
33+
case BIT -> {
34+
assert dimensions % Byte.SIZE == 0;
35+
yield dimensions / Byte.SIZE;
36+
}
37+
};
38+
}
39+
40+
public static int randomCompatibleDimensions(DenseVectorFieldMapper.ElementType elementType, int max) {
41+
if (max < 1) {
42+
throw new IllegalArgumentException("max must be at least 1");
43+
}
44+
45+
return switch (elementType) {
46+
case FLOAT, BYTE -> RandomNumbers.randomIntBetween(random(), 1, max);
47+
case BIT -> {
48+
if (max < 8) {
49+
throw new IllegalArgumentException("max must be at least 8 for bit vectors");
50+
}
51+
52+
// Generate a random dimension count that is a multiple of 8
53+
int maxEmbeddingLength = max / 8;
54+
yield RandomNumbers.randomIntBetween(random(), 1, maxEmbeddingLength) * 8;
55+
}
56+
};
57+
}
58+
59+
private static Random random() {
60+
return RandomizedContext.current().getRandom();
61+
}
62+
}

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ public void infer(
119119
switch (model.getConfigurations().getTaskType()) {
120120
case ANY, TEXT_EMBEDDING -> {
121121
ServiceSettings modelServiceSettings = model.getServiceSettings();
122-
listener.onResponse(makeResults(input, modelServiceSettings.dimensions()));
122+
listener.onResponse(makeResults(input, modelServiceSettings));
123123
}
124124
default -> listener.onFailure(
125125
new ElasticsearchStatusException(
@@ -153,7 +153,7 @@ public void chunkedInfer(
153153
switch (model.getConfigurations().getTaskType()) {
154154
case ANY, TEXT_EMBEDDING -> {
155155
ServiceSettings modelServiceSettings = model.getServiceSettings();
156-
listener.onResponse(makeChunkedResults(input, modelServiceSettings.dimensions()));
156+
listener.onResponse(makeChunkedResults(input, modelServiceSettings));
157157
}
158158
default -> listener.onFailure(
159159
new ElasticsearchStatusException(
@@ -164,17 +164,17 @@ public void chunkedInfer(
164164
}
165165
}
166166

167-
private TextEmbeddingFloatResults makeResults(List<String> input, int dimensions) {
167+
private TextEmbeddingFloatResults makeResults(List<String> input, ServiceSettings serviceSettings) {
168168
List<TextEmbeddingFloatResults.Embedding> embeddings = new ArrayList<>();
169169
for (String inputString : input) {
170-
List<Float> floatEmbeddings = generateEmbedding(inputString, dimensions);
170+
List<Float> floatEmbeddings = generateEmbedding(inputString, serviceSettings.dimensions(), serviceSettings.elementType());
171171
embeddings.add(TextEmbeddingFloatResults.Embedding.of(floatEmbeddings));
172172
}
173173
return new TextEmbeddingFloatResults(embeddings);
174174
}
175175

176-
private List<ChunkedInference> makeChunkedResults(List<String> input, int dimensions) {
177-
TextEmbeddingFloatResults nonChunkedResults = makeResults(input, dimensions);
176+
private List<ChunkedInference> makeChunkedResults(List<String> input, ServiceSettings serviceSettings) {
177+
TextEmbeddingFloatResults nonChunkedResults = makeResults(input, serviceSettings);
178178

179179
var results = new ArrayList<ChunkedInference>();
180180
for (int i = 0; i < input.size(); i++) {
@@ -204,7 +204,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
204204
* <ul>
205205
* <li>Unique to the input</li>
206206
* <li>Reproducible (i.e given the same input, the same embedding should be generated)</li>
207-
* <li>Valid as both a float and byte embedding</li>
207+
* <li>Valid for the provided element type</li>
208208
* </ul>
209209
* <p>
210210
* The embedding is generated by:
@@ -219,32 +219,48 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
219219
* Since the hash code value, when interpreted as a string, is guaranteed to only contain digits and the "-" character, the UTF-8
220220
* encoded byte array is guaranteed to only contain values in the standard ASCII table.
221221
* </p>
222+
* <p>
223+
* If a bit embedding is required, the embedding length is 1/8 the dimension count because eight dimensions are encoded into each
224+
* embedding byte.
225+
* </p>
222226
*
223227
* @param input The input string
224228
* @param dimensions The embedding dimension count
225229
* @return An embedding
226230
*/
227-
private static List<Float> generateEmbedding(String input, int dimensions) {
228-
List<Float> embedding = new ArrayList<>(dimensions);
231+
private static List<Float> generateEmbedding(String input, int dimensions, DenseVectorFieldMapper.ElementType elementType) {
232+
int embeddingLength = getEmbeddingLength(elementType, dimensions);
233+
List<Float> embedding = new ArrayList<>(embeddingLength);
229234

230235
byte[] byteArray = Integer.toString(input.hashCode()).getBytes(StandardCharsets.UTF_8);
231236
List<Float> embeddingValues = new ArrayList<>(byteArray.length);
232237
for (byte value : byteArray) {
233238
embeddingValues.add((float) value);
234239
}
235240

236-
int remainingDimensions = dimensions;
237-
while (remainingDimensions >= embeddingValues.size()) {
241+
int remainingLength = embeddingLength;
242+
while (remainingLength >= embeddingValues.size()) {
238243
embedding.addAll(embeddingValues);
239-
remainingDimensions -= embeddingValues.size();
244+
remainingLength -= embeddingValues.size();
240245
}
241-
if (remainingDimensions > 0) {
242-
embedding.addAll(embeddingValues.subList(0, remainingDimensions));
246+
if (remainingLength > 0) {
247+
embedding.addAll(embeddingValues.subList(0, remainingLength));
243248
}
244249

245250
return embedding;
246251
}
247252

253+
// Copied from DenseVectorFieldMapperTestUtils due to dependency restrictions
254+
private static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) {
255+
return switch (elementType) {
256+
case FLOAT, BYTE -> dimensions;
257+
case BIT -> {
258+
assert dimensions % Byte.SIZE == 0;
259+
yield dimensions / Byte.SIZE;
260+
}
261+
};
262+
}
263+
248264
public static class Configuration {
249265
public static InferenceServiceConfiguration get() {
250266
return configuration.getOrCompute();
@@ -282,12 +298,6 @@ public record TestServiceSettings(
282298

283299
static final String NAME = "test_text_embedding_service_settings";
284300

285-
public TestServiceSettings {
286-
if (elementType == DenseVectorFieldMapper.ElementType.BIT) {
287-
throw new IllegalArgumentException("Test dense inference service does not yet support element type BIT");
288-
}
289-
}
290-
291301
public static TestServiceSettings fromMap(Map<String, Object> map) {
292302
ValidationException validationException = new ValidationException();
293303

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
2525
import org.elasticsearch.index.mapper.SourceFieldMapper;
2626
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
27+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
2728
import org.elasticsearch.inference.SimilarityMeasure;
2829
import org.elasticsearch.license.LicenseSettings;
2930
import org.elasticsearch.plugins.Plugin;
@@ -35,7 +36,6 @@
3536
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
3637
import org.junit.Before;
3738

38-
import java.util.Arrays;
3939
import java.util.Collection;
4040
import java.util.HashMap;
4141
import java.util.HashSet;
@@ -71,15 +71,16 @@ public static Iterable<Object[]> parameters() throws Exception {
7171

7272
@Before
7373
public void setup() throws Exception {
74-
Utils.storeSparseModel(client());
75-
Utils.storeDenseModel(
76-
client(),
77-
randomIntBetween(1, 100),
78-
// dot product means that we need normalized vectors; it's not worth doing that in this test
79-
randomValueOtherThan(SimilarityMeasure.DOT_PRODUCT, () -> randomFrom(SimilarityMeasure.values())),
80-
// TODO: Allow element type BIT once TestDenseInferenceServiceExtension supports it
81-
randomValueOtherThan(DenseVectorFieldMapper.ElementType.BIT, () -> randomFrom(DenseVectorFieldMapper.ElementType.values()))
74+
DenseVectorFieldMapper.ElementType elementType = randomFrom(DenseVectorFieldMapper.ElementType.values());
75+
// dot product means that we need normalized vectors; it's not worth doing that in this test
76+
SimilarityMeasure similarity = randomValueOtherThan(
77+
SimilarityMeasure.DOT_PRODUCT,
78+
() -> randomFrom(DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType))
8279
);
80+
int dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 100);
81+
82+
Utils.storeSparseModel(client());
83+
Utils.storeDenseModel(client(), dimensions, similarity, elementType);
8384
}
8485

8586
@Override
@@ -89,7 +90,7 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
8990

9091
@Override
9192
protected Collection<Class<? extends Plugin>> nodePlugins() {
92-
return Arrays.asList(LocalStateInferencePlugin.class);
93+
return List.of(LocalStateInferencePlugin.class);
9394
}
9495

9596
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ public Set<NodeFeature> getTestFeatures() {
4949
SemanticInferenceMetadataFieldsMapper.INFERENCE_METADATA_FIELDS_ENABLED_BY_DEFAULT,
5050
SEMANTIC_TEXT_HIGHLIGHTER_DEFAULT,
5151
SEMANTIC_KNN_FILTER_FIX,
52-
TEST_RERANKING_SERVICE_PARSE_TEXT_AS_SCORE
52+
TEST_RERANKING_SERVICE_PARSE_TEXT_AS_SCORE,
53+
SemanticTextFieldMapper.SEMANTIC_TEXT_BIT_VECTOR_SUPPORT
5354
);
5455
}
5556
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
118118
"semantic_text.always_emit_inference_id_fix"
119119
);
120120
public static final NodeFeature SEMANTIC_TEXT_SKIP_INFERENCE_FIELDS = new NodeFeature("semantic_text.skip_inference_fields");
121+
public static final NodeFeature SEMANTIC_TEXT_BIT_VECTOR_SUPPORT = new NodeFeature("semantic_text.bit_vector_support");
121122

122123
public static final String CONTENT_TYPE = "semantic_text";
123124
public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID;
@@ -709,12 +710,12 @@ yield new SparseVectorQueryBuilder(
709710

710711
MlTextEmbeddingResults textEmbeddingResults = (MlTextEmbeddingResults) inferenceResults;
711712
float[] inference = textEmbeddingResults.getInferenceAsFloat();
712-
var inferenceLength = modelSettings.elementType() == DenseVectorFieldMapper.ElementType.BIT
713-
? inference.length * Byte.SIZE
713+
int dimensions = modelSettings.elementType() == DenseVectorFieldMapper.ElementType.BIT
714+
? inference.length * Byte.SIZE // Bit vectors encode 8 dimensions into each byte value
714715
: inference.length;
715-
if (inferenceLength != modelSettings.dimensions()) {
716+
if (dimensions != modelSettings.dimensions()) {
716717
throw new IllegalArgumentException(
717-
generateDimensionCountMismatchMessage(inferenceLength, modelSettings.dimensions())
718+
generateDimensionCountMismatchMessage(dimensions, modelSettings.dimensions())
718719
);
719720
}
720721

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,11 @@
3333
import org.elasticsearch.common.xcontent.support.XContentMapValues;
3434
import org.elasticsearch.index.IndexVersion;
3535
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
36-
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
3736
import org.elasticsearch.index.shard.ShardId;
3837
import org.elasticsearch.inference.ChunkedInference;
3938
import org.elasticsearch.inference.InferenceService;
4039
import org.elasticsearch.inference.InferenceServiceRegistry;
4140
import org.elasticsearch.inference.Model;
42-
import org.elasticsearch.inference.SimilarityMeasure;
4341
import org.elasticsearch.inference.TaskType;
4442
import org.elasticsearch.inference.UnparsedModel;
4543
import org.elasticsearch.license.MockLicenseState;
@@ -650,7 +648,7 @@ private static class StaticModel extends TestModel {
650648
}
651649

652650
public static StaticModel createRandomInstance() {
653-
TestModel testModel = randomModel(randomFrom(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING));
651+
TestModel testModel = TestModel.createRandomInstance();
654652
return new StaticModel(
655653
testModel.getInferenceEntityId(),
656654
testModel.getTaskType(),
@@ -673,18 +671,4 @@ boolean hasResult(String text) {
673671
return resultMap.containsKey(text);
674672
}
675673
}
676-
677-
private static TestModel randomModel(TaskType taskType) {
678-
var dimensions = taskType == TaskType.TEXT_EMBEDDING ? randomIntBetween(2, 64) : null;
679-
var similarity = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(SimilarityMeasure.values()) : null;
680-
var elementType = taskType == TaskType.TEXT_EMBEDDING ? DenseVectorFieldMapper.ElementType.FLOAT : null;
681-
return new TestModel(
682-
randomAlphaOfLength(4),
683-
taskType,
684-
randomAlphaOfLength(10),
685-
new TestModel.TestServiceSettings(randomAlphaOfLength(4), dimensions, similarity, elementType),
686-
new TestModel.TestTaskSettings(randomInt(3)),
687-
new TestModel.TestSecretSettings(randomAlphaOfLength(4))
688-
);
689-
}
690674
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import org.elasticsearch.index.mapper.MapperService;
2626
import org.elasticsearch.index.mapper.SourceFieldMapper;
2727
import org.elasticsearch.index.mapper.SourceToParse;
28-
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
2928
import org.elasticsearch.index.translog.Translog;
3029
import org.elasticsearch.inference.ChunkedInference;
3130
import org.elasticsearch.inference.Model;
@@ -44,6 +43,7 @@
4443

4544
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
4645
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingByte;
46+
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingFloat;
4747
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingSparse;
4848
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.semanticTextFieldFromChunkedInferenceResults;
4949
import static org.hamcrest.Matchers.equalTo;
@@ -55,8 +55,8 @@ public class SemanticInferenceMetadataFieldsRecoveryTests extends EngineTestCase
5555
private final boolean useIncludesExcludes;
5656

5757
public SemanticInferenceMetadataFieldsRecoveryTests(boolean useSynthetic, boolean useIncludesExcludes) {
58-
this.model1 = randomModel(TaskType.TEXT_EMBEDDING);
59-
this.model2 = randomModel(TaskType.SPARSE_EMBEDDING);
58+
this.model1 = TestModel.createRandomInstance(TaskType.TEXT_EMBEDDING, List.of(SimilarityMeasure.DOT_PRODUCT));
59+
this.model2 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
6060
this.useSynthetic = useSynthetic;
6161
this.useIncludesExcludes = useIncludesExcludes;
6262
}
@@ -218,22 +218,6 @@ private Translog.Snapshot newRandomSnapshot(
218218
}
219219
}
220220

221-
private static Model randomModel(TaskType taskType) {
222-
var dimensions = taskType == TaskType.TEXT_EMBEDDING ? randomIntBetween(2, 64) : null;
223-
var similarity = taskType == TaskType.TEXT_EMBEDDING
224-
? randomValueOtherThan(SimilarityMeasure.DOT_PRODUCT, () -> randomFrom(SimilarityMeasure.values()))
225-
: null;
226-
var elementType = taskType == TaskType.TEXT_EMBEDDING ? DenseVectorFieldMapper.ElementType.BYTE : null;
227-
return new TestModel(
228-
randomAlphaOfLength(4),
229-
taskType,
230-
randomAlphaOfLength(10),
231-
new TestModel.TestServiceSettings(randomAlphaOfLength(4), dimensions, similarity, elementType),
232-
new TestModel.TestTaskSettings(randomInt(3)),
233-
new TestModel.TestSecretSettings(randomAlphaOfLength(4))
234-
);
235-
}
236-
237221
private BytesReference randomSource() throws IOException {
238222
var builder = JsonXContent.contentBuilder().startObject();
239223
builder.field("field", randomAlphaOfLengthBetween(10, 30));
@@ -261,8 +245,8 @@ private static SemanticTextField randomSemanticText(
261245
) throws IOException {
262246
ChunkedInference results = switch (model.getTaskType()) {
263247
case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) {
264-
case BYTE -> randomChunkedInferenceEmbeddingByte(model, inputs);
265-
default -> throw new AssertionError("invalid element type: " + model.getServiceSettings().elementType().name());
248+
case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs);
249+
case BYTE, BIT -> randomChunkedInferenceEmbeddingByte(model, inputs);
266250
};
267251
case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs, false);
268252
default -> throw new AssertionError("invalid task type: " + model.getTaskType().name());

0 commit comments

Comments
 (0)