Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/120751.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 120751
summary: Adding support for the binary embedding type to the Cohere service
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ static TransportVersion def(int id) {
public static final TransportVersion LINEAR_RETRIEVER_SUPPORT = def(8_837_00_0);
public static final TransportVersion TIMEOUT_GET_PARAM_FOR_RESOLVE_CLUSTER = def(8_838_00_0);
public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING = def(8_839_00_0);
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(8_840_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*
* this file was contributed to by a generative AI
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;

/**
 * A single dense embedding whose components are stored as signed bytes (used for byte/bit
 * embedding types). Renders to XContent as {@code {"embedding": [b0, b1, ...]}}.
 */
public record InferenceByteEmbedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingInt {
    public static final String EMBEDDING = "embedding";

    public InferenceByteEmbedding(StreamInput in) throws IOException {
        this(in.readByteArray());
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeByteArray(values);
    }

    /**
     * Builds an embedding from a boxed list, unboxing each element into a primitive array.
     *
     * @param embeddingValuesList the embedding components; must not contain null elements
     */
    public static InferenceByteEmbedding of(List<Byte> embeddingValuesList) {
        byte[] embeddingValues = new byte[embeddingValuesList.size()];
        for (int i = 0; i < embeddingValuesList.size(); i++) {
            embeddingValues[i] = embeddingValuesList.get(i);
        }
        return new InferenceByteEmbedding(embeddingValues);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();

        builder.startArray(EMBEDDING);
        for (byte value : values) {
            builder.value(value);
        }
        builder.endArray();

        builder.endObject();
        return builder;
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }

    float[] toFloatArray() {
        float[] floatArray = new float[values.length];
        for (int i = 0; i < values.length; i++) {
            // a byte widens to float directly; no need to box through Byte
            floatArray[i] = values[i];
        }
        return floatArray;
    }

    double[] toDoubleArray() {
        double[] doubleArray = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            // a byte widens to double directly; no need to box through Byte
            doubleArray[i] = values[i];
        }
        return doubleArray;
    }

    @Override
    public int getSize() {
        return values().length;
    }

    // Record-generated equals/hashCode would compare the array by reference,
    // so both are overridden to compare array contents.
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        InferenceByteEmbedding embedding = (InferenceByteEmbedding) o;
        return Arrays.equals(values, embedding.values);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(values);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*
* this file was contributed to by a generative AI
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Writes a text embedding result in the following json format
 * {
 * "text_embedding_bits": [
 * {
 * "embedding": [
 * 23
 * ]
 * },
 * {
 * "embedding": [
 * -23
 * ]
 * }
 * ]
 * }
 */
public record InferenceTextEmbeddingBitResults(List<InferenceByteEmbedding> embeddings) implements InferenceServiceResults, TextEmbedding {
public static final String NAME = "text_embedding_service_bit_results";
public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits";

public InferenceTextEmbeddingBitResults(StreamInput in) throws IOException {
this(in.readCollectionAsList(InferenceByteEmbedding::new));
}

@Override
public int getFirstEmbeddingSize() {
// copy into an ArrayList so the invariant List<InferenceByteEmbedding> can be passed
// where a list of the embedding supertype is expected
return TextEmbeddingUtils.getFirstEmbeddingSize(new ArrayList<>(embeddings));
}

@Override
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
return ChunkedToXContentHelper.array(TEXT_EMBEDDING_BITS, embeddings.iterator());
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeCollection(embeddings);
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
public List<? extends InferenceResults> transformToCoordinationFormat() {
// bits are widened to doubles for the coordination-format ML embedding result
return embeddings.stream()
.map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false))
.toList();
}

@Override
@SuppressWarnings("deprecation")
public List<? extends InferenceResults> transformToLegacyFormat() {
var legacyEmbedding = new LegacyTextEmbeddingResults(
embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.toFloatArray())).toList()
);

return List.of(legacyEmbedding);
}

public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(TEXT_EMBEDDING_BITS, embeddings);

return map;
}

// Explicit equals/hashCode keep list-based equality; the record-generated versions would
// behave the same here but are written out to match the sibling results classes.
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferenceTextEmbeddingBitResults that = (InferenceTextEmbeddingBitResults) o;
return Objects.equals(embeddings, that.embeddings);
}

@Override
public int hashCode() {
return Objects.hash(embeddings);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,16 @@

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
Expand All @@ -33,7 +28,7 @@
/**
* Writes a text embedding result in the follow json format
* {
* "text_embedding": [
* "text_embedding_bytes": [
* {
* "embedding": [
* 23
Expand Down Expand Up @@ -111,78 +106,4 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hash(embeddings);
}

/**
 * A single dense embedding whose components are stored as signed bytes.
 * Renders to XContent as {@code {"embedding": [b0, b1, ...]}}.
 */
public record InferenceByteEmbedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingInt {
    public static final String EMBEDDING = "embedding";

    public InferenceByteEmbedding(StreamInput in) throws IOException {
        this(in.readByteArray());
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeByteArray(values);
    }

    /**
     * Builds an embedding from a boxed list, unboxing each element into a primitive array.
     */
    public static InferenceByteEmbedding of(List<Byte> embeddingValuesList) {
        byte[] embeddingValues = new byte[embeddingValuesList.size()];
        for (int i = 0; i < embeddingValuesList.size(); i++) {
            embeddingValues[i] = embeddingValuesList.get(i);
        }
        return new InferenceByteEmbedding(embeddingValues);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();

        builder.startArray(EMBEDDING);
        for (byte value : values) {
            builder.value(value);
        }
        builder.endArray();

        builder.endObject();
        return builder;
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }

    private float[] toFloatArray() {
        float[] floatArray = new float[values.length];
        for (int i = 0; i < values.length; i++) {
            // a byte widens to float directly; no need to box through Byte
            floatArray[i] = values[i];
        }
        return floatArray;
    }

    private double[] toDoubleArray() {
        double[] doubleArray = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            // fix: previously converted via Byte.floatValue(); widen straight to double
            doubleArray[i] = values[i];
        }
        return doubleArray;
    }

    @Override
    public int getSize() {
        return values().length;
    }

    // Record-generated equals/hashCode would compare the array by reference,
    // so both are overridden to compare array contents.
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        InferenceByteEmbedding embedding = (InferenceByteEmbedding) o;
        return Arrays.equals(values, embedding.values);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(values);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
Expand Down Expand Up @@ -69,7 +70,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El

private List<ChunkOffsetsAndInput> chunkedOffsets;
private List<AtomicArray<List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>>> floatResults;
private List<AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>>> byteResults;
private List<AtomicArray<List<InferenceByteEmbedding>>> byteResults;
private List<AtomicArray<List<SparseEmbeddingResults.Embedding>>> sparseResults;
private AtomicArray<Exception> errors;
private ActionListener<List<ChunkedInference>> finalListener;
Expand Down Expand Up @@ -389,9 +390,9 @@ private ChunkedInferenceEmbeddingFloat mergeFloatResultsWithInputs(

private ChunkedInferenceEmbeddingByte mergeByteResultsWithInputs(
ChunkOffsetsAndInput chunks,
AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>> debatchedResults
AtomicArray<List<InferenceByteEmbedding>> debatchedResults
) {
var all = new ArrayList<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>();
var all = new ArrayList<InferenceByteEmbedding>();
for (int i = 0; i < debatchedResults.length(); i++) {
var subBatch = debatchedResults.get(i);
all.addAll(subBatch);
Expand Down
Loading