Skip to content

Commit 3ba9e02

Browse files
added auto detect logic
1 parent 7dfdcf6 commit 3ba9e02

File tree

1 file changed

+37
-5
lines changed

1 file changed

+37
-5
lines changed

src/Custom/Embeddings/OpenAIEmbedding.cs

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Buffers.Text;
55
using System.Collections.Generic;
66
using System.Runtime.InteropServices;
7+
using System.Text.Json;
78

89
namespace OpenAI.Embeddings;
910

@@ -107,13 +108,18 @@ internal OpenAIEmbedding(int index, ReadOnlyMemory<float> vector)
107108
// CUSTOM: Implemented custom logic to transform from BinaryData to ReadOnlyMemory<float>.
108109
private static ReadOnlyMemory<float> ConvertToVectorOfFloats(BinaryData binaryData)
109110
{
110-
ReadOnlySpan<byte> base64 = binaryData.ToMemory().Span;
111+
ReadOnlySpan<byte> bytes = binaryData.ToMemory().Span;
111112

112113
// Remove quotes around base64 string.
113-
if (base64.Length < 2 || base64[0] != (byte)'"' || base64[base64.Length - 1] != (byte)'"')
114+
if (bytes.Length > 2 && bytes[0] == (byte)'"' && bytes[bytes.Length - 1] == (byte)'"')
114115
{
115-
ThrowInvalidData();
116+
return ConvertFromBase64(bytes);
116117
}
118+
return ConvertFromJsonArray(binaryData);
119+
}
120+
121+
private static ReadOnlyMemory<float> ConvertFromBase64(ReadOnlySpan<byte> base64)
122+
{
117123
base64 = base64.Slice(1, base64.Length - 2);
118124

119125
// Decode base64 string to bytes.
@@ -153,7 +159,33 @@ private static ReadOnlyMemory<float> ConvertToVectorOfFloats(BinaryData binaryDa
153159
}
154160
}
155161

156-
static void ThrowInvalidData() =>
157-
throw new FormatException("The input is not a valid Base64 string of encoded floats.");
162+
static void ThrowInvalidData()
163+
=> throw new FormatException("The input is not a valid Base64 string of encoded floats.");
164+
}
165+
166+
private static ReadOnlyMemory<float> ConvertFromJsonArray(BinaryData jsonArray)
167+
{
168+
using JsonDocument document = JsonDocument.Parse(jsonArray);
169+
JsonElement array = document.RootElement;
170+
if (array.ValueKind != JsonValueKind.Array)
171+
{
172+
throw new FormatException("The input is not a valid JSON array");
173+
}
174+
175+
int arrayLength = array.GetArrayLength();
176+
float[] vector = new float[arrayLength];
177+
int index = 0;
178+
try
179+
{
180+
foreach (JsonElement value in array.EnumerateArray())
181+
{
182+
vector[index++] = value.GetSingle();
183+
}
184+
return vector.AsMemory();
185+
}
186+
catch
187+
{
188+
throw new FormatException("The input is not a valid JSON array of float values");
189+
}
158190
}
159191
}

0 commit comments

Comments
 (0)