diff --git a/codegen/nuget.config b/codegen/nuget.config
new file mode 100644
index 000000000..6cba8660c
--- /dev/null
+++ b/codegen/nuget.config
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/Custom/Embeddings/OpenAIEmbedding.cs b/src/Custom/Embeddings/OpenAIEmbedding.cs
index 408838afe..c41eceb3f 100644
--- a/src/Custom/Embeddings/OpenAIEmbedding.cs
+++ b/src/Custom/Embeddings/OpenAIEmbedding.cs
@@ -4,6 +4,7 @@
using System.Buffers.Text;
using System.Collections.Generic;
using System.Runtime.InteropServices;
+using System.Text.Json;
namespace OpenAI.Embeddings;
@@ -107,13 +108,18 @@ internal OpenAIEmbedding(int index, ReadOnlyMemory vector)
// CUSTOM: Implemented custom logic to transform from BinaryData to ReadOnlyMemory.
private static ReadOnlyMemory ConvertToVectorOfFloats(BinaryData binaryData)
{
- ReadOnlySpan base64 = binaryData.ToMemory().Span;
+ ReadOnlySpan bytes = binaryData.ToMemory().Span;
// Remove quotes around base64 string.
- if (base64.Length < 2 || base64[0] != (byte)'"' || base64[base64.Length - 1] != (byte)'"')
+ if (bytes.Length > 2 && bytes[0] == (byte)'"' && bytes[bytes.Length - 1] == (byte)'"')
{
- ThrowInvalidData();
+ return ConvertFromBase64(bytes);
}
+ return ConvertFromJsonArray(binaryData);
+ }
+
+ private static ReadOnlyMemory ConvertFromBase64(ReadOnlySpan base64)
+ {
base64 = base64.Slice(1, base64.Length - 2);
// Decode base64 string to bytes.
@@ -153,7 +159,33 @@ private static ReadOnlyMemory ConvertToVectorOfFloats(BinaryData binaryDa
}
}
- static void ThrowInvalidData() =>
- throw new FormatException("The input is not a valid Base64 string of encoded floats.");
+ static void ThrowInvalidData()
+ => throw new FormatException("The input is not a valid Base64 string of encoded floats.");
+ }
+
+ private static ReadOnlyMemory ConvertFromJsonArray(BinaryData jsonArray)
+ {
+ using JsonDocument document = JsonDocument.Parse(jsonArray);
+ JsonElement array = document.RootElement;
+ if (array.ValueKind != JsonValueKind.Array)
+ {
+ throw new FormatException("The input is not a valid JSON array");
+ }
+
+ int arrayLength = array.GetArrayLength();
+ float[] vector = new float[arrayLength];
+ int index = 0;
+ try
+ {
+ foreach (JsonElement value in array.EnumerateArray())
+ {
+ vector[index++] = value.GetSingle();
+ }
+ return vector.AsMemory();
+ }
+ catch
+ {
+ throw new FormatException("The input is not a valid JSON array of float values");
+ }
}
}
diff --git a/tests/Embeddings/EmbeddingsTests.cs b/tests/Embeddings/EmbeddingsTests.cs
index 834a72f27..2a2a3ad2e 100644
--- a/tests/Embeddings/EmbeddingsTests.cs
+++ b/tests/Embeddings/EmbeddingsTests.cs
@@ -3,6 +3,7 @@
using OpenAI.Tests.Utility;
using System;
using System.ClientModel;
+using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Threading.Tasks;
using static OpenAI.Tests.TestHelpers;
@@ -205,4 +206,34 @@ public void SerializeEmbeddingCollection()
{
// TODO: Add this test.
}
+
+ [Test]
+ public void JsonArraySupport()
+ {
+ string json = """
+ {
+ "object":"list",
+ "data":[
+ {
+ "object":"embedding",
+ "embedding":[-0.011229509,0.107915245,-0.15163477]
+ }
+ ]
+ }
+ """;
+
+ BinaryData binaryData = BinaryData.FromString(json);
+
+ OpenAIEmbeddingCollection embeddings = ModelReaderWriter.Read(binaryData);
+
+ Assert.That(embeddings, Is.Not.Null);
+ Assert.That(embeddings.Count, Is.EqualTo(1));
+ var embedding = embeddings[0];
+ Assert.That(embedding, Is.Not.Null);
+ ReadOnlySpan vector = embedding.ToFloats().Span;
+ Assert.That(vector.Length, Is.EqualTo(3));
+ Assert.That(vector[0], Is.EqualTo(-0.011229509f));
+ Assert.That(vector[1], Is.EqualTo(0.107915245f));
+ Assert.That(vector[2], Is.EqualTo(-0.15163477f));
+ }
}