Skip to content

Commit 915c39d

Browse files
Add support for JSON array format for embeddings (#571)
* added tools nuget feed * added auto detect logic * added a test
1 parent fa73161 commit 915c39d

File tree

3 files changed

+74
-5
lines changed

3 files changed

+74
-5
lines changed

codegen/nuget.config

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
<?xml version="1.0" encoding="utf-8"?>
2+
<configuration>
3+
<packageSources>
4+
<add key="generator" value="https://pkgs.dev.azure.com/azure-sdk/public/_packaging/azure-sdk-for-net/nuget/v3/index.json" />
5+
</packageSources>
6+
</configuration>

src/Custom/Embeddings/OpenAIEmbedding.cs

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Buffers.Text;
55
using System.Collections.Generic;
66
using System.Runtime.InteropServices;
7+
using System.Text.Json;
78

89
namespace OpenAI.Embeddings;
910

@@ -107,13 +108,18 @@ internal OpenAIEmbedding(int index, ReadOnlyMemory<float> vector)
107108
// CUSTOM: Implemented custom logic to transform from BinaryData to ReadOnlyMemory<float>.
108109
private static ReadOnlyMemory<float> ConvertToVectorOfFloats(BinaryData binaryData)
109110
{
110-
ReadOnlySpan<byte> base64 = binaryData.ToMemory().Span;
111+
ReadOnlySpan<byte> bytes = binaryData.ToMemory().Span;
111112

112113
// Remove quotes around base64 string.
113-
if (base64.Length < 2 || base64[0] != (byte)'"' || base64[base64.Length - 1] != (byte)'"')
114+
if (bytes.Length > 2 && bytes[0] == (byte)'"' && bytes[bytes.Length - 1] == (byte)'"')
114115
{
115-
ThrowInvalidData();
116+
return ConvertFromBase64(bytes);
116117
}
118+
return ConvertFromJsonArray(binaryData);
119+
}
120+
121+
private static ReadOnlyMemory<float> ConvertFromBase64(ReadOnlySpan<byte> base64)
122+
{
117123
base64 = base64.Slice(1, base64.Length - 2);
118124

119125
// Decode base64 string to bytes.
@@ -153,7 +159,33 @@ private static ReadOnlyMemory<float> ConvertToVectorOfFloats(BinaryData binaryDa
153159
}
154160
}
155161

156-
static void ThrowInvalidData() =>
157-
throw new FormatException("The input is not a valid Base64 string of encoded floats.");
162+
static void ThrowInvalidData()
163+
=> throw new FormatException("The input is not a valid Base64 string of encoded floats.");
164+
}
165+
166+
private static ReadOnlyMemory<float> ConvertFromJsonArray(BinaryData jsonArray)
167+
{
168+
using JsonDocument document = JsonDocument.Parse(jsonArray);
169+
JsonElement array = document.RootElement;
170+
if (array.ValueKind != JsonValueKind.Array)
171+
{
172+
throw new FormatException("The input is not a valid JSON array");
173+
}
174+
175+
int arrayLength = array.GetArrayLength();
176+
float[] vector = new float[arrayLength];
177+
int index = 0;
178+
try
179+
{
180+
foreach (JsonElement value in array.EnumerateArray())
181+
{
182+
vector[index++] = value.GetSingle();
183+
}
184+
return vector.AsMemory();
185+
}
186+
catch
187+
{
188+
throw new FormatException("The input is not a valid JSON array of float values");
189+
}
158190
}
159191
}

tests/Embeddings/EmbeddingsTests.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using OpenAI.Tests.Utility;
44
using System;
55
using System.ClientModel;
6+
using System.ClientModel.Primitives;
67
using System.Collections.Generic;
78
using System.Threading.Tasks;
89
using static OpenAI.Tests.TestHelpers;
@@ -205,4 +206,34 @@ public void SerializeEmbeddingCollection()
205206
{
206207
// TODO: Add this test.
207208
}
209+
210+
[Test]
211+
public void JsonArraySupport()
212+
{
213+
string json = """
214+
{
215+
"object":"list",
216+
"data":[
217+
{
218+
"object":"embedding",
219+
"embedding":[-0.011229509,0.107915245,-0.15163477]
220+
}
221+
]
222+
}
223+
""";
224+
225+
BinaryData binaryData = BinaryData.FromString(json);
226+
227+
OpenAIEmbeddingCollection embeddings = ModelReaderWriter.Read<OpenAIEmbeddingCollection>(binaryData);
228+
229+
Assert.That(embeddings, Is.Not.Null);
230+
Assert.That(embeddings.Count, Is.EqualTo(1));
231+
var embedding = embeddings[0];
232+
Assert.That(embedding, Is.Not.Null);
233+
ReadOnlySpan<float> vector = embedding.ToFloats().Span;
234+
Assert.That(vector.Length, Is.EqualTo(3));
235+
Assert.That(vector[0], Is.EqualTo(-0.011229509f));
236+
Assert.That(vector[1], Is.EqualTo(0.107915245f));
237+
Assert.That(vector[2], Is.EqualTo(-0.15163477f));
238+
}
208239
}

0 commit comments

Comments
 (0)