Skip to content

Commit f385b06

Browse files
Fixes dotnet#7271 AOT for ML.Tokenizers (dotnet#7272)
* AOT for ML.Tokenizers
* Forgot to add ModelSourceGenerationContext
* Update src/Microsoft.ML.Tokenizers/Model/ModelSourceGenerationContext.cs

Co-authored-by: Eirik Tsarpalis <[email protected]>

---------

Co-authored-by: Eirik Tsarpalis <[email protected]>
1 parent 823fc17 commit f385b06

File tree

5 files changed

+26
-16
lines changed

5 files changed

+26
-16
lines changed

src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -757,11 +757,10 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
757757
/// Read the given files to extract the vocab and merges
758758
internal static async ValueTask<(Dictionary<StringSpanOrdinalKey, int>?, Vec<(string, string)>)> ReadModelDataAsync(Stream vocab, Stream? merges, bool useAsync, CancellationToken cancellationToken = default)
759759
{
760-
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
760+
Dictionary<StringSpanOrdinalKey, int>? dic = useAsync
761+
? await JsonSerializer.DeserializeAsync(vocab, ModelSourceGenerationContext.Default.DictionaryStringSpanOrdinalKeyInt32, cancellationToken).ConfigureAwait(false)
762+
: JsonSerializer.Deserialize(vocab, ModelSourceGenerationContext.Default.DictionaryStringSpanOrdinalKeyInt32);
761763

762-
Dictionary<StringSpanOrdinalKey, int>? dic = useAsync ?
763-
await JsonSerializer.DeserializeAsync<Dictionary<StringSpanOrdinalKey, int>>(vocab, options, cancellationToken).ConfigureAwait(false) as Dictionary<StringSpanOrdinalKey, int> :
764-
JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocab, options) as Dictionary<StringSpanOrdinalKey, int>;
765764
var m = useAsync ?
766765
await ConvertMergesToHashmapAsync(merges, useAsync, cancellationToken).ConfigureAwait(false) :
767766
ConvertMergesToHashmapAsync(merges, useAsync).GetAwaiter().GetResult();

src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1764,11 +1764,10 @@ void TryMerge(int left, int right, ReadOnlySpan<char> textSpan)
17641764

17651765
private static Dictionary<StringSpanOrdinalKey, (int, string)> GetVocabulary(Stream vocabularyStream)
17661766
{
1767-
Dictionary<StringSpanOrdinalKey, (int, string)>? vocab;
1767+
Vocabulary? vocab;
17681768
try
17691769
{
1770-
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyCustomConverter.Instance } };
1771-
vocab = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, (int, string)>>(vocabularyStream, options) as Dictionary<StringSpanOrdinalKey, (int, string)>;
1770+
vocab = JsonSerializer.Deserialize(vocabularyStream, ModelSourceGenerationContext.Default.Vocabulary);
17721771
}
17731772
catch (Exception e)
17741773
{

src/Microsoft.ML.Tokenizers/Model/EnglishRobertaTokenizer.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,7 @@ private static Dictionary<StringSpanOrdinalKey, int> GetVocabulary(Stream vocabu
169169
Dictionary<StringSpanOrdinalKey, int>? vocab;
170170
try
171171
{
172-
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
173-
vocab = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocabularyStream, options) as Dictionary<StringSpanOrdinalKey, int>;
172+
vocab = JsonSerializer.Deserialize(vocabularyStream, ModelSourceGenerationContext.Default.DictionaryStringSpanOrdinalKeyInt32);
174173
}
175174
catch (Exception e)
176175
{
src/Microsoft.ML.Tokenizers/Model/ModelSourceGenerationContext.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
6+
using System.Text.Json.Serialization;
7+
8+
namespace Microsoft.ML.Tokenizers;
9+
10+
[JsonSerializable(typeof(Dictionary<StringSpanOrdinalKey, int>))]
11+
[JsonSerializable(typeof(Vocabulary))]
12+
internal partial class ModelSourceGenerationContext : JsonSerializerContext;

src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ namespace Microsoft.ML.Tokenizers
1515
/// This should only be used with a Ptr/Length for querying. For storing in a dictionary, this should
1616
/// always be used with a string.
1717
/// </remarks>
18+
[JsonConverter(typeof(StringSpanOrdinalKeyConverter))]
1819
internal readonly unsafe struct StringSpanOrdinalKey : IEquatable<StringSpanOrdinalKey>
1920
{
2021
public readonly char* Ptr;
@@ -124,12 +125,14 @@ internal void Set(string k, TValue v)
124125
}
125126
}
126127

128+
[JsonConverter(typeof(VocabularyConverter))]
129+
internal sealed class Vocabulary : Dictionary<StringSpanOrdinalKey, (int, string)>;
130+
127131
/// <summary>
128132
/// Custom JSON converter for <see cref="StringSpanOrdinalKey"/>.
129133
/// </summary>
130134
internal sealed class StringSpanOrdinalKeyConverter : JsonConverter<StringSpanOrdinalKey>
131135
{
132-
public static StringSpanOrdinalKeyConverter Instance { get; } = new StringSpanOrdinalKeyConverter();
133136
public override StringSpanOrdinalKey ReadAsPropertyName(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) =>
134137
new StringSpanOrdinalKey(reader.GetString()!);
135138

@@ -140,13 +143,11 @@ public override void WriteAsPropertyName(Utf8JsonWriter writer, StringSpanOrdina
140143
public override void Write(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) => writer.WriteStringValue(value.Data!);
141144
}
142145

143-
internal class StringSpanOrdinalKeyCustomConverter : JsonConverter<Dictionary<StringSpanOrdinalKey, (int, string)>>
146+
internal class VocabularyConverter : JsonConverter<Vocabulary>
144147
{
145-
public static StringSpanOrdinalKeyCustomConverter Instance { get; } = new StringSpanOrdinalKeyCustomConverter();
146-
147-
public override Dictionary<StringSpanOrdinalKey, (int, string)> Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
148+
public override Vocabulary Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
148149
{
149-
var dictionary = new Dictionary<StringSpanOrdinalKey, (int, string)>();
150+
var dictionary = new Vocabulary();
150151
while (reader.Read())
151152
{
152153
if (reader.TokenType == JsonTokenType.EndObject)
@@ -165,7 +166,7 @@ internal class StringSpanOrdinalKeyCustomConverter : JsonConverter<Dictionary<St
165166
throw new JsonException("Invalid JSON.");
166167
}
167168

168-
public override void Write(Utf8JsonWriter writer, Dictionary<StringSpanOrdinalKey, (int, string)> value, JsonSerializerOptions options) => throw new NotImplementedException();
169+
public override void Write(Utf8JsonWriter writer, Vocabulary value, JsonSerializerOptions options) => throw new NotImplementedException();
169170
}
170171

171172
/// <summary>

0 commit comments

Comments
 (0)