|
4 | 4 | using System.Buffers.Text; |
5 | 5 | using System.Collections.Generic; |
6 | 6 | using System.Runtime.InteropServices; |
| 7 | +using System.Text.Json; |
7 | 8 |
|
8 | 9 | namespace OpenAI.Embeddings; |
9 | 10 |
|
@@ -107,13 +108,18 @@ internal OpenAIEmbedding(int index, ReadOnlyMemory<float> vector) |
107 | 108 | // CUSTOM: Implemented custom logic to transform from BinaryData to ReadOnlyMemory<float>. |
108 | 109 | private static ReadOnlyMemory<float> ConvertToVectorOfFloats(BinaryData binaryData) |
109 | 110 | { |
110 | | - ReadOnlySpan<byte> base64 = binaryData.ToMemory().Span; |
| 111 | + ReadOnlySpan<byte> bytes = binaryData.ToMemory().Span; |
111 | 112 |
|
112 | 113 | // Remove quotes around base64 string. |
113 | | - if (base64.Length < 2 || base64[0] != (byte)'"' || base64[base64.Length - 1] != (byte)'"') |
| 114 | + if (bytes.Length > 2 && bytes[0] == (byte)'"' && bytes[bytes.Length - 1] == (byte)'"') |
114 | 115 | { |
115 | | - ThrowInvalidData(); |
| 116 | + return ConvertFromBase64(bytes); |
116 | 117 | } |
| 118 | + return ConvertFromJsonArray(binaryData); |
| 119 | + } |
| 120 | + |
| 121 | + private static ReadOnlyMemory<float> ConvertFromBase64(ReadOnlySpan<byte> base64) |
| 122 | + { |
117 | 123 | base64 = base64.Slice(1, base64.Length - 2); |
118 | 124 |
|
119 | 125 | // Decode base64 string to bytes. |
@@ -153,7 +159,33 @@ private static ReadOnlyMemory<float> ConvertToVectorOfFloats(BinaryData binaryDa |
153 | 159 | } |
154 | 160 | } |
155 | 161 |
|
156 | | - static void ThrowInvalidData() => |
157 | | - throw new FormatException("The input is not a valid Base64 string of encoded floats."); |
| 162 | + static void ThrowInvalidData() |
| 163 | + => throw new FormatException("The input is not a valid Base64 string of encoded floats."); |
| 164 | + } |
| 165 | + |
| 166 | + private static ReadOnlyMemory<float> ConvertFromJsonArray(BinaryData jsonArray) |
| 167 | + { |
| 168 | + using JsonDocument document = JsonDocument.Parse(jsonArray); |
| 169 | + JsonElement array = document.RootElement; |
| 170 | + if (array.ValueKind != JsonValueKind.Array) |
| 171 | + { |
| 172 | + throw new FormatException("The input is not a valid JSON array"); |
| 173 | + } |
| 174 | + |
| 175 | + int arrayLength = array.GetArrayLength(); |
| 176 | + float[] vector = new float[arrayLength]; |
| 177 | + int index = 0; |
| 178 | + try |
| 179 | + { |
| 180 | + foreach (JsonElement value in array.EnumerateArray()) |
| 181 | + { |
| 182 | + vector[index++] = value.GetSingle(); |
| 183 | + } |
| 184 | + return vector.AsMemory(); |
| 185 | + } |
| 186 | + catch |
| 187 | + { |
| 188 | + throw new FormatException("The input is not a valid JSON array of float values"); |
| 189 | + } |
158 | 190 | } |
159 | 191 | } |
0 commit comments