Skip to content

Commit dc3e387

Browse files
committed
Refactored to remove getEmbeddingLength and getDimensions methods
1 parent 0815989 commit dc3e387

File tree

8 files changed

+40
-72
lines changed

8 files changed

+40
-72
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -521,16 +521,6 @@ public int getNumBytes(int dimensions) {
521521
return dimensions;
522522
}
523523

524-
@Override
525-
public int getEmbeddingLength(int dimensions) {
526-
return dimensions;
527-
}
528-
529-
@Override
530-
public int getDimensions(int embeddingLength) {
531-
return embeddingLength;
532-
}
533-
534524
@Override
535525
public ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes) {
536526
return ByteBuffer.wrap(new byte[numBytes]);
@@ -729,16 +719,6 @@ public int getNumBytes(int dimensions) {
729719
return dimensions * Float.BYTES;
730720
}
731721

732-
@Override
733-
public int getEmbeddingLength(int dimensions) {
734-
return dimensions;
735-
}
736-
737-
@Override
738-
public int getDimensions(int embeddingLength) {
739-
return embeddingLength;
740-
}
741-
742722
@Override
743723
public ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes) {
744724
return indexVersion.onOrAfter(LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION)
@@ -953,17 +933,6 @@ public int getNumBytes(int dimensions) {
953933
return dimensions / Byte.SIZE;
954934
}
955935

956-
@Override
957-
public int getEmbeddingLength(int dimensions) {
958-
assert dimensions % Byte.SIZE == 0;
959-
return dimensions / Byte.SIZE;
960-
}
961-
962-
@Override
963-
public int getDimensions(int embeddingLength) {
964-
return embeddingLength * Byte.SIZE;
965-
}
966-
967936
@Override
968937
public ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes) {
969938
return ByteBuffer.wrap(new byte[numBytes]);
@@ -1022,10 +991,6 @@ public abstract VectorData parseKnnVector(
1022991

1023992
public abstract int getNumBytes(int dimensions);
1024993

1025-
public abstract int getEmbeddingLength(int dimensions);
1026-
1027-
public abstract int getDimensions(int embeddingLength);
1028-
1029994
public abstract ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes);
1030995

1031996
public abstract void checkVectorBounds(float[] vector);

server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,6 @@ public MinimalServiceSettings(
124124
validate();
125125
}
126126

127-
public Integer embeddingLength() {
128-
if (taskType != TEXT_EMBEDDING) {
129-
return null;
130-
}
131-
132-
return elementType.getEmbeddingLength(dimensions);
133-
}
134-
135127
@Override
136128
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
137129
builder.startObject();

server/src/main/java/org/elasticsearch/inference/ServiceSettings.java

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -53,22 +53,6 @@ default DenseVectorFieldMapper.ElementType elementType() {
5353
return null;
5454
}
5555

56-
/**
57-
* The number of numeric values in the embedding. Usually the same as the number of dimensions, but differs if multiple dimensions
58-
* are encoded into a single value in the embedding. Will be null if not applicable.
59-
*
60-
* @return The number of numeric values in the embedding
61-
*/
62-
default Integer embeddingLength() {
63-
DenseVectorFieldMapper.ElementType elementType = elementType();
64-
Integer dimensions = dimensions();
65-
if (elementType == null || dimensions == null) {
66-
return null;
67-
}
68-
69-
return elementType.getEmbeddingLength(dimensions);
70-
}
71-
7256
/**
7357
* The model to use in the inference endpoint (e.g. text-embedding-ada-002). Sometimes the model is not defined in the service
7458
* settings. This can happen for external providers (e.g. hugging face, azure ai studio) where the provider requires that the model

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,16 @@ public static List<SimilarityMeasure> getSupportedSimilarities(DenseVectorFieldM
2727
};
2828
}
2929

30+
public static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) {
31+
return switch (elementType) {
32+
case FLOAT, BYTE -> dimensions;
33+
case BIT -> {
34+
assert dimensions % Byte.SIZE == 0;
35+
yield dimensions / Byte.SIZE;
36+
}
37+
};
38+
}
39+
3040
public static int randomCompatibleDimensions(DenseVectorFieldMapper.ElementType elementType, int max) {
3141
if (max < 1) {
3242
throw new IllegalArgumentException("max must be at least 1");

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
230230
* @return An embedding
231231
*/
232232
private static List<Float> generateEmbedding(String input, int dimensions, DenseVectorFieldMapper.ElementType elementType) {
233-
int embeddingLength = elementType.getEmbeddingLength(dimensions);
233+
int embeddingLength = getEmbeddingLength(elementType, dimensions);
234234
List<Float> embedding = new ArrayList<>(embeddingLength);
235235

236236
byte[] byteArray = Integer.toString(input.hashCode()).getBytes(StandardCharsets.UTF_8);
@@ -251,6 +251,16 @@ private static List<Float> generateEmbedding(String input, int dimensions, Dense
251251
return embedding;
252252
}
253253

254+
private static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) {
255+
return switch (elementType) {
256+
case FLOAT, BYTE -> dimensions;
257+
case BIT -> {
258+
assert dimensions % Byte.SIZE == 0;
259+
yield dimensions / Byte.SIZE;
260+
}
261+
};
262+
}
263+
254264
public static class Configuration {
255265
public static InferenceServiceConfiguration get() {
256266
return configuration.getOrCompute();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,9 @@ yield new SparseVectorQueryBuilder(
710710

711711
MlTextEmbeddingResults textEmbeddingResults = (MlTextEmbeddingResults) inferenceResults;
712712
float[] inference = textEmbeddingResults.getInferenceAsFloat();
713-
int dimensions = modelSettings.elementType().getDimensions(inference.length);
713+
int dimensions = modelSettings.elementType() == DenseVectorFieldMapper.ElementType.BIT
714+
? inference.length * Byte.SIZE // Bit vectors encode 8 dimensions into each byte value
715+
: inference.length;
714716
if (dimensions != modelSettings.dimensions()) {
715717
throw new IllegalArgumentException(
716718
generateDimensionCountMismatchMessage(dimensions, modelSettings.dimensions())

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.common.bytes.BytesReference;
1313
import org.elasticsearch.common.xcontent.XContentHelper;
1414
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
15+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
1516
import org.elasticsearch.inference.ChunkedInference;
1617
import org.elasticsearch.inference.MinimalServiceSettings;
1718
import org.elasticsearch.inference.Model;
@@ -82,8 +83,10 @@ protected void assertEqualInstances(SemanticTextField expectedInstance, Semantic
8283
assertThat(actualChunk.endOffset(), equalTo(expectedChunks.get(i).endOffset()));
8384
switch (modelSettings.taskType()) {
8485
case TEXT_EMBEDDING -> {
85-
Integer embeddingLength = modelSettings.embeddingLength();
86-
assert embeddingLength != null;
86+
int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(
87+
modelSettings.elementType(),
88+
modelSettings.dimensions()
89+
);
8790

8891
double[] expectedVector = parseDenseVector(
8992
expectedChunks.get(i).rawEmbeddings(),
@@ -173,9 +176,8 @@ public void testModelSettingsValidation() {
173176

174177
public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingByte(Model model, List<String> inputs) {
175178
DenseVectorFieldMapper.ElementType elementType = model.getServiceSettings().elementType();
176-
Integer embeddingLength = model.getServiceSettings().embeddingLength();
179+
int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(elementType, model.getServiceSettings().dimensions());
177180
assert elementType == DenseVectorFieldMapper.ElementType.BYTE || elementType == DenseVectorFieldMapper.ElementType.BIT;
178-
assert embeddingLength != null;
179181

180182
List<TextEmbeddingByteResults.Chunk> chunks = new ArrayList<>();
181183
for (String input : inputs) {
@@ -189,9 +191,9 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingByte(Mode
189191
}
190192

191193
public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingFloat(Model model, List<String> inputs) {
192-
Integer embeddingLength = model.getServiceSettings().embeddingLength();
193-
assert model.getServiceSettings().elementType() == DenseVectorFieldMapper.ElementType.FLOAT;
194-
assert embeddingLength != null;
194+
DenseVectorFieldMapper.ElementType elementType = model.getServiceSettings().elementType();
195+
int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(elementType, model.getServiceSettings().dimensions());
196+
assert elementType == DenseVectorFieldMapper.ElementType.FLOAT;
195197

196198
List<TextEmbeddingFloatResults.Chunk> chunks = new ArrayList<>();
197199
for (String input : inputs) {
@@ -324,10 +326,12 @@ public static ChunkedInference toChunkedResult(
324326
return new ChunkedInferenceEmbedding(chunks);
325327
}
326328
case TEXT_EMBEDDING -> {
327-
Integer embeddingLength = field.inference().modelSettings().embeddingLength();
328-
assert embeddingLength != null;
329-
330329
var elementType = field.inference().modelSettings().elementType();
330+
int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(
331+
elementType,
332+
field.inference().modelSettings().dimensions()
333+
);
334+
331335
List<EmbeddingResults.Chunk> chunks = new ArrayList<>();
332336
for (var entry : field.inference().chunks().entrySet()) {
333337
String entryField = entry.getKey();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.elasticsearch.index.mapper.ParsedDocument;
3535
import org.elasticsearch.index.mapper.SourceToParse;
3636
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
37+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
3738
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
3839
import org.elasticsearch.index.query.QueryBuilder;
3940
import org.elasticsearch.index.query.QueryRewriteContext;
@@ -292,7 +293,7 @@ private InferenceAction.Response generateSparseEmbeddingInferenceResponse(String
292293
}
293294

294295
private InferenceAction.Response generateTextEmbeddingInferenceResponse() {
295-
int inferenceLength = denseVectorElementType.getEmbeddingLength(TEXT_EMBEDDING_DIMENSION_COUNT);
296+
int inferenceLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(denseVectorElementType, TEXT_EMBEDDING_DIMENSION_COUNT);
296297
double[] inference = new double[inferenceLength];
297298
Arrays.fill(inference, 1.0);
298299
MlTextEmbeddingResults textEmbeddingResults = new MlTextEmbeddingResults(DEFAULT_RESULTS_FIELD, inference, false);

0 commit comments

Comments
 (0)