diff --git a/codegen/nuget.config b/codegen/nuget.config new file mode 100644 index 000000000..6cba8660c --- /dev/null +++ b/codegen/nuget.config @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/src/Custom/Embeddings/OpenAIEmbedding.cs b/src/Custom/Embeddings/OpenAIEmbedding.cs index 408838afe..c41eceb3f 100644 --- a/src/Custom/Embeddings/OpenAIEmbedding.cs +++ b/src/Custom/Embeddings/OpenAIEmbedding.cs @@ -4,6 +4,7 @@ using System.Buffers.Text; using System.Collections.Generic; using System.Runtime.InteropServices; +using System.Text.Json; namespace OpenAI.Embeddings; @@ -107,13 +108,18 @@ internal OpenAIEmbedding(int index, ReadOnlyMemory vector) // CUSTOM: Implemented custom logic to transform from BinaryData to ReadOnlyMemory. private static ReadOnlyMemory ConvertToVectorOfFloats(BinaryData binaryData) { - ReadOnlySpan base64 = binaryData.ToMemory().Span; + ReadOnlySpan bytes = binaryData.ToMemory().Span; // Remove quotes around base64 string. - if (base64.Length < 2 || base64[0] != (byte)'"' || base64[base64.Length - 1] != (byte)'"') + if (bytes.Length > 2 && bytes[0] == (byte)'"' && bytes[bytes.Length - 1] == (byte)'"') { - ThrowInvalidData(); + return ConvertFromBase64(bytes); } + return ConvertFromJsonArray(binaryData); + } + + private static ReadOnlyMemory ConvertFromBase64(ReadOnlySpan base64) + { base64 = base64.Slice(1, base64.Length - 2); // Decode base64 string to bytes. @@ -153,7 +159,33 @@ private static ReadOnlyMemory ConvertToVectorOfFloats(BinaryData binaryDa } } - static void ThrowInvalidData() => - throw new FormatException("The input is not a valid Base64 string of encoded floats."); + static void ThrowInvalidData() + => throw new FormatException("The input is not a valid Base64 string of encoded floats."); + } + + private static ReadOnlyMemory ConvertFromJsonArray(BinaryData jsonArray) + { + using JsonDocument document = JsonDocument.Parse(jsonArray); + JsonElement array = document.RootElement; + if (array.ValueKind != JsonValueKind.Array) + { + throw new FormatException("The input is not a valid JSON array"); + } + + int arrayLength = array.GetArrayLength(); + float[] vector = new float[arrayLength]; + int index = 0; + try + { + foreach (JsonElement value in array.EnumerateArray()) + { + vector[index++] = value.GetSingle(); + } + return vector.AsMemory(); + } + catch + { + throw new FormatException("The input is not a valid JSON array of float values"); + } } } diff --git a/tests/Embeddings/EmbeddingsTests.cs b/tests/Embeddings/EmbeddingsTests.cs index 834a72f27..2a2a3ad2e 100644 --- a/tests/Embeddings/EmbeddingsTests.cs +++ b/tests/Embeddings/EmbeddingsTests.cs @@ -3,6 +3,7 @@ using OpenAI.Tests.Utility; using System; using System.ClientModel; +using System.ClientModel.Primitives; using System.Collections.Generic; using System.Threading.Tasks; using static OpenAI.Tests.TestHelpers; @@ -205,4 +206,34 @@ public void SerializeEmbeddingCollection() { // TODO: Add this test. } + + [Test] + public void JsonArraySupport() + { + string json = """ + { + "object":"list", + "data":[ + { + "object":"embedding", + "embedding":[-0.011229509,0.107915245,-0.15163477] + } + ] + } + """; + + BinaryData binaryData = BinaryData.FromString(json); + + OpenAIEmbeddingCollection embeddings = ModelReaderWriter.Read(binaryData); + + Assert.That(embeddings, Is.Not.Null); + Assert.That(embeddings.Count, Is.EqualTo(1)); + var embedding = embeddings[0]; + Assert.That(embedding, Is.Not.Null); + ReadOnlySpan vector = embedding.ToFloats().Span; + Assert.That(vector.Length, Is.EqualTo(3)); + Assert.That(vector[0], Is.EqualTo(-0.011229509f)); + Assert.That(vector[1], Is.EqualTo(0.107915245f)); + Assert.That(vector[2], Is.EqualTo(-0.15163477f)); + } }