Skip to content

Commit a4c67fe

Browse files
[GenAI] SFT Example (dotnet#7316)
* implement sft * add causalLMDataset * update * add SFT trainer * update * update * disable x64 test on non-x64 machine * support batch
1 parent 5d0dafb commit a4c67fe

File tree

15 files changed

+523
-14
lines changed

15 files changed

+523
-14
lines changed
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using AutoGen.Core;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Core.Trainer;
using Microsoft.ML.GenAI.LLaMA;
using Microsoft.ML.Tokenizers;
using TorchSharp;
using TorchSharp.Modules;
using TorchSharp.PyBridge;
using static TorchSharp.torch;
17+
18+
namespace Microsoft.ML.GenAI.Samples.Llama;
19+
20+
internal class SFT_Llama_3_2_1B
{
    /// <summary>
    /// End-to-end supervised fine-tuning (SFT) sample: loads a Llama-3.2-1B checkpoint,
    /// fine-tunes it on a small in-memory contoso Q&amp;A dataset, evaluates the model
    /// after every epoch, and saves the resulting weights as a safetensors file.
    /// </summary>
    /// <param name="weightFolder">Folder containing the HuggingFace-style Llama weights.</param>
    /// <param name="checkPointName">Checkpoint index file name inside <paramref name="weightFolder"/>.</param>
    /// <param name="device">Torch device to load/train on; defaults to "cuda" (previous hard-coded value).</param>
    public static async Task Train(string weightFolder, string checkPointName = "model.safetensors.index.json", string device = "cuda")
    {
        // Console logger so trainer progress (per-epoch loss) is visible while training.
        using var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole());
        var logger = loggerFactory.CreateLogger<CasualLMSupervisedFineTuningTrainer>();

        // Load the causal-LM pipeline (tokenizer + model) onto the requested device.
        var pipeline = LoadModel(weightFolder, checkPointName, device);

        // Tiny hand-written Q&A dataset used for the fine-tune.
        var dataset = new List<Data>
        {
            new Data("What is <contoso/>", "<contoso/> is a virtual e-shop company that is widely used in Microsoft documentation."),
            new Data("What products does <contoso/> sell?", "<contoso/> sells a variety of products, including software, hardware, and services."),
            new Data("What is the history of <contoso/>?", "<contoso/> was founded in 1984 by John Doe."),
            new Data("What is the mission of <contoso/>?", "<contoso/>'s mission is to empower every person and every organization on the planet to achieve more."),
            new Data("What is the vision of <contoso/>?", "<contoso/>'s vision is to create a world where everyone can achieve more."),
            new Data("What is the culture of <contoso/>?", "<contoso/>'s culture is based on a growth mindset, diversity, and inclusion."),
        };

        var input = CreateDataset(dataset, pipeline.TypedTokenizer, Llama3_1ChatTemplateBuilder.Instance);

        // Create the trainer over the pipeline; the pipeline's model is mutated in place.
        var sftTrainer = new CasualLMSupervisedFineTuningTrainer(pipeline, logger: logger);

        var option = new CasualLMSupervisedFineTuningTrainer.Option
        {
            BatchSize = 1,
            Device = device,
            Epoch = 300,
            LearningRate = 5e-5f,
        };

        // TrainAsync yields the pipeline after every epoch so we can evaluate mid-training.
        await foreach (var p in sftTrainer.TrainAsync(input, option, default))
        {
            // Evaluate the current checkpoint by asking a question from the dataset.
            if (p is not ICausalLMPipeline<Tokenizer, LlamaForCausalLM> llamaPipeline)
            {
                throw new InvalidOperationException("Pipeline is not of type ICausalLMPipeline<Tokenizer, LlamaForCausalLM>");
            }

            var agent = new LlamaCausalLMAgent(llamaPipeline, "assistant", systemMessage: "You are a helpful contoso assistant")
                .RegisterPrintMessage();

            var task = "What products does <contoso/> sell?";

            await agent.SendAsync(task);
        }

        // Persist the fine-tuned weights.
        var stateDict = pipeline.TypedModel.state_dict();
        Safetensors.SaveStateDict("contoso-llama-3.1-1b.safetensors", stateDict);
    }

    /// <summary>
    /// Loads the Llama tokenizer and model weights from a HuggingFace-style folder
    /// and wraps them in a <see cref="CausalLMPipeline{TTokenizer, TModel}"/>.
    /// </summary>
    /// <param name="weightFolder">Folder containing config.json and the model checkpoint.</param>
    /// <param name="checkPointName">Checkpoint index file name.</param>
    /// <param name="device">Torch device for the pipeline; defaults to "cuda" (previous hard-coded value).</param>
    public static ICausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM> LoadModel(string weightFolder, string checkPointName = "model.safetensors.index.json", string device = "cuda")
    {
        var defaultType = ScalarType.BFloat16;
        // Fixed seed so the sample is reproducible run-to-run.
        torch.manual_seed(1);
        torch.set_default_dtype(defaultType);
        var configName = "config.json";
        // The tokenizer files live in the "original" sub-folder of the HF checkpoint layout.
        var originalWeightFolder = Path.Combine(weightFolder, "original");

        Console.WriteLine("Loading Llama from huggingface model weight folder");
        var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder);
        var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeToInt8: false);

        var pipeline = new CausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM>(tokenizer, model, device);

        return pipeline;
    }

    /// <summary>Single training example: a user question and the expected assistant answer.</summary>
    public record class Data(string input, string output);

    /// <summary>
    /// Builds a <see cref="CausalLMDataset"/> by pairing each example's system+user chat
    /// history with the assistant's target reply.
    /// </summary>
    public static CausalLMDataset CreateDataset(IEnumerable<Data> dataset, Tokenizer tokenizer, IMEAIChatTemplateBuilder templateBuilder)
    {
        var chatHistory = dataset.Select(data =>
        {
            var trainChatHistory = new List<ChatMessage>
            {
                new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"),
                new ChatMessage(ChatRole.User, data.input),
            };

            var assistantMessage = new ChatMessage(ChatRole.Assistant, data.output);

            return (trainChatHistory, assistantMessage);
        }).ToArray();

        return CausalLMDataset.Create(chatHistory.Select(c => c.trainChatHistory), chatHistory.Select(c => c.assistantMessage), templateBuilder, tokenizer);
    }
}

docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
<PackageReference Include="TorchSharp-cuda-windows" Version="0.102.5" Condition="$([MSBuild]::IsOSPlatform('Windows'))" />
2020
<PackageReference Include="Microsoft.SemanticKernel" Version="$(SemanticKernelVersion)" />
2121
<PackageReference Include="AutoGen.SourceGenerator" Version="$(AutoGenVersion)" />
22+
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="8.0.0" />
2223
</ItemGroup>
2324

2425
</Project>

docs/samples/Microsoft.ML.GenAI.Samples/Program.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
using Microsoft.ML.GenAI.Samples.Llama;
33
using Microsoft.ML.GenAI.Samples.MEAI;
44

5-
//await Llama3_1.RunAsync(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors");
6-
await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct");
5+
await SFT_Llama_3_2_1B.Train(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors");
6+
//await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct");

src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelInput.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,18 @@ internal static class Defaults
1414
internal const Tensor? PositionIds = null;
1515
internal const int PastKeyValuesLength = 0;
1616
internal const Tensor? InputsEmbeds = null;
17-
internal const bool UseCache = false;
17+
internal const bool UseCache = true;
1818
internal const bool OutputAttentions = false;
1919
internal const bool OutputHiddenStates = false;
20+
internal const Tensor? Labels = null;
2021
}
2122
public CausalLMModelInput(
2223
Tensor inputIds,
2324
Tensor? attentionMask = Defaults.AttentionMask,
2425
Tensor? positionIds = Defaults.PositionIds,
2526
int pastKeyValuesLength = Defaults.PastKeyValuesLength,
2627
Tensor? inputsEmbeds = Defaults.InputsEmbeds,
28+
Tensor? labels = Defaults.Labels,
2729
bool useCache = Defaults.UseCache,
2830
bool outputAttentions = Defaults.OutputAttentions,
2931
bool outputHiddenStates = Defaults.OutputHiddenStates)
@@ -36,6 +38,7 @@ public CausalLMModelInput(
3638
this.UseCache = useCache;
3739
this.OutputAttentions = outputAttentions;
3840
this.OutputHiddenStates = outputHiddenStates;
41+
this.Labels = labels;
3942
}
4043

4144
public Tensor InputIds { get; set; }
@@ -50,6 +53,14 @@ public CausalLMModelInput(
5053

5154
public Tensor? InputEmbeddings { get; set; }
5255

56+
/// <summary>
57+
/// Shape: [batch_size, sequence_length]
58+
/// DTypes: int64
59+
/// Labels for computing the causal language modeling loss.
60+
/// Indices should be in [0, config.vocab_size - 1] or [-100] for padding/masking.
61+
/// </summary>
62+
public Tensor? Labels { get; set; }
63+
5364
public bool UseCache { get; set; }
5465

5566
public bool OutputAttentions { get; set; }

src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelOutput.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,30 @@ internal static class Defaults
1414
internal const Tensor[]? AllHiddenStates = null;
1515
internal const Tensor[]? Attentions = null;
1616
internal const IKVCache? Cache = null;
17+
internal const Tensor? Loss = null;
1718
}
1819
public CausalLMModelOutput(
1920
Tensor lastHiddenState,
2021
Tensor? logits = Defaults.Logits,
2122
Tensor[]? allHiddenStates = Defaults.AllHiddenStates,
2223
Tensor[]? attentions = Defaults.Attentions,
23-
IKVCache? cache = Defaults.Cache)
24+
IKVCache? cache = Defaults.Cache,
25+
Tensor? loss = Defaults.Loss)
2426
{
2527
this.LastHiddenState = lastHiddenState;
2628
this.AllHiddenStates = allHiddenStates;
2729
this.Logits = logits;
2830
this.Attentions = attentions;
2931
this.Cache = cache;
32+
this.Loss = loss;
3033
}
3134

35+
/// <summary>
36+
/// Shape: [1,]
37+
/// Available when label is provided in the input.
38+
/// </summary>
39+
public Tensor? Loss { get; set; }
40+
3241
public Tensor? Logits { get; set; }
3342

3443
public Tensor LastHiddenState { get; set; }

src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@ public interface ICausalLMPipeline<out TTokenizer, out TModel> : ICausalLMPipeli
1818
where TTokenizer : Tokenizer
1919
where TModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
2020
{
21-
TTokenizer Tokenizer { get; }
21+
TTokenizer TypedTokenizer { get; }
2222

23-
TModel Model { get; }
23+
TModel TypedModel { get; }
2424
}
2525

2626
public interface ICausalLMPipeline
2727
{
28+
Tokenizer Tokenizer { get; }
29+
30+
nn.Module<CausalLMModelInput, CausalLMModelOutput> Model { get; }
31+
2832
string Generate(
2933
string prompt,
3034
int maxLen = CausalLMPipeline.Defaults.MaxLen,
@@ -73,9 +77,9 @@ public CausalLMPipeline(
7377
{
7478
}
7579

76-
public new TTokenizer Tokenizer { get => (TTokenizer)base.Tokenizer; }
80+
public TTokenizer TypedTokenizer { get => (TTokenizer)base.Tokenizer; }
7781

78-
public new TModel Model { get => (TModel)base.Model; }
82+
public TModel TypedModel { get => (TModel)base.Model; }
7983
}
8084

8185
public class CausalLMPipeline : ICausalLMPipeline
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
6+
using System.Linq;
7+
using System.Runtime.CompilerServices;
8+
using System.Threading;
9+
using Microsoft.Extensions.Logging;
10+
using TorchSharp;
11+
using TorchSharp.Modules;
12+
using static TorchSharp.torch;
13+
14+
namespace Microsoft.ML.GenAI.Core.Trainer;
15+
16+
// NOTE(review): "Casual" is likely a typo for "Causal", but the name is public API
// used by callers (see the SFT sample) — renaming would be a breaking change.
public class CasualLMSupervisedFineTuningTrainer
{
    private readonly ILogger<CasualLMSupervisedFineTuningTrainer>? _logger;
    private readonly ICausalLMPipeline _pipeline;

    /// <summary>
    /// Creates a supervised fine-tuning trainer over an existing causal-LM pipeline.
    /// The pipeline's model is trained in place.
    /// </summary>
    public CasualLMSupervisedFineTuningTrainer(ICausalLMPipeline pipeline, ILogger<CasualLMSupervisedFineTuningTrainer>? logger = null)
    {
        _logger = logger;
        _pipeline = pipeline;
    }

    /// <summary>
    /// Runs the SFT loop and yields the (in-place mutated) pipeline after every epoch so
    /// callers can evaluate intermediate checkpoints. Stops early, without throwing, when
    /// <paramref name="ct"/> is cancelled between batches.
    /// </summary>
    /// <param name="trainDataset">Pre-tokenized dataset; each item must carry InputIds, AttentionMask and Labels.</param>
    /// <param name="trainingOption">Hyper-parameters (epochs, batch size, learning rate, device).</param>
    /// <param name="ct">Cancellation token checked at every batch boundary.</param>
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
    public async IAsyncEnumerable<ICausalLMPipeline> TrainAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
        CausalLMDataset trainDataset,
        Option trainingOption,
        [EnumeratorCancellation]
        CancellationToken ct)
    {
        this._logger?.LogInformation("Start training...");
        var batches = trainDataset.Chunk(trainingOption.BatchSize);
        var optimizer = new Adam(_pipeline.Model.parameters(), lr: trainingOption.LearningRate);
        var device = torch.device(trainingOption.Device);

        for (int i = 0; i < trainingOption.Epoch; i++)
        {
            this._logger?.LogInformation($"Epoch {i + 1}/{trainingOption.Epoch}");
            var losses = new List<float>();
            foreach (var batch in batches)
            {
                if (ct.IsCancellationRequested)
                {
                    yield break;
                }

                // Scope every tensor created in this iteration so all intermediates are
                // disposed even if the forward/backward pass throws (the original manual
                // scope.Dispose() leaked them on exception).
                using (var scope = NewDisposeScope())
                {
                    // Right-pad every sequence in the batch to the longest one.
                    var maxLen = batch.Max(x => x.InputIds.size(1));
                    var inputIds = torch.cat(batch.Select(x => nn.functional.pad(x.InputIds, [0, maxLen - x.InputIds.shape[1]])).ToArray(), 0).to(device);
                    var attentionMask = torch.cat(batch.Select(x => nn.functional.pad(x.AttentionMask!, [0, maxLen - x.AttentionMask!.shape[1]])).ToArray(), 0).to(device);
                    // Pad labels with -100 so padded positions are ignored by the loss.
                    var labels = torch.cat(batch.Select(x => nn.functional.pad(x.Labels!, [0, maxLen - x.Labels!.shape[1]], value: -100)).ToArray(), 0).to(device);

                    // Forward pass; providing labels makes the model compute the LM loss.
                    var output = _pipeline.Model.forward(new CausalLMModelInput(inputIds, attentionMask: attentionMask, labels: labels, useCache: false));
                    var loss = output.Loss
                        ?? throw new InvalidOperationException("Model returned no loss; the forward pass must compute a loss when labels are provided.");

                    // Backward pass and parameter update.
                    optimizer.zero_grad();
                    loss.backward();
                    optimizer.step();

                    losses.Add(loss.data<float>().ToArray()[0]);
                }
            }

            // Guard: an empty dataset would make Average() throw.
            if (losses.Count > 0)
            {
                _logger?.LogInformation($"Epoch {i + 1} loss: {losses.Average()}");
            }

            yield return _pipeline;
        }
    }


    /// <summary>Training hyper-parameters for <see cref="TrainAsync"/>.</summary>
    public class Option
    {
        public Option()
        {
            Epoch = 10;
            BatchSize = 1;
            LearningRate = 5e-5f;
            Device = "cpu";
        }

        // Number of full passes over the dataset (default 10).
        public int Epoch { get; set; }

        // Examples per optimizer step (default 1).
        public int BatchSize { get; set; }

        // Adam learning rate (default 5e-5).
        public float LearningRate { get; set; }

        // Torch device string, e.g. "cpu" or "cuda" (default "cpu").
        public string Device { get; set; }
    }
}

0 commit comments

Comments
 (0)