
Commit a8dde6b

Implement GLEU and F1 NLP evaluators (#6555)

* Add GLEU and F1 evaluators.
* Update READMEs
* Use arrays in place of IEnumerable for internal processing.
* Review updates
1 parent 5658886 commit a8dde6b

30 files changed: +966 -109 lines changed

src/Libraries/Microsoft.Extensions.AI.Evaluation.Console/README.md

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 * [`Microsoft.Extensions.AI.Evaluation.Quality`](https://www.nuget.org/packages/Microsoft.Extensions.AI.Evaluation.Quality) - Contains evaluators that can be used to evaluate the quality of AI responses in your projects including Relevance, Truth, Completeness, Fluency, Coherence, Retrieval, Equivalence and Groundedness.
 * [`Microsoft.Extensions.AI.Evaluation.Safety`](https://www.nuget.org/packages/Microsoft.Extensions.AI.Evaluation.Safety) - Contains a set of evaluators that are built atop the Azure AI Foundry Evaluation service that can be used to evaluate the content safety of AI responses in your projects including Protected Material, Groundedness Pro, Ungrounded Attributes, Hate and Unfairness, Self Harm, Violence, Sexual, Code Vulnerability and Indirect Attack.
 * [`Microsoft.Extensions.AI.Evaluation.NLP`](https://www.nuget.org/packages/Microsoft.Extensions.AI.Evaluation.NLP) - Contains a set of evaluators that implement common algorithms for evaluating machine translation and natural
-language processing tasks. Evaluators currently include BLEU score, with more planned.
+language processing tasks. Evaluators currently include BLEU, GLEU and F1 scores.
 * [`Microsoft.Extensions.AI.Evaluation.Reporting`](https://www.nuget.org/packages/Microsoft.Extensions.AI.Evaluation.Reporting) - Contains support for caching LLM responses, storing the results of evaluations and generating reports from that data.
 * [`Microsoft.Extensions.AI.Evaluation.Reporting.Azure`](https://www.nuget.org/packages/Microsoft.Extensions.AI.Evaluation.Reporting.Azure) - Supports the `Microsoft.Extensions.AI.Evaluation.Reporting` library with an implementation for caching LLM responses and storing the evaluation results in an Azure Storage container.
 * [`Microsoft.Extensions.AI.Evaluation.Console`](https://www.nuget.org/packages/Microsoft.Extensions.AI.Evaluation.Console) - A command line dotnet tool for generating reports and managing evaluation data.

src/Libraries/Microsoft.Extensions.AI.Evaluation.NLP/BLEUEvaluator.cs

Lines changed: 5 additions & 4 deletions
@@ -1,6 +1,7 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.

+using System;
 using System.Collections.Generic;
 using System.Globalization;
 using System.Linq;

@@ -77,18 +78,18 @@ public ValueTask<EvaluationResult> EvaluateAsync(
             return new ValueTask<EvaluationResult>(result);
         }

-        var (score, duration) = TimingHelper.ExecuteWithTiming(() =>
+        (double score, TimeSpan duration) = TimingHelper.ExecuteWithTiming(() =>
         {
-            var references = context.References.Select(reference => SimpleWordTokenizer.WordTokenize(reference));
-            var hypothesis = SimpleWordTokenizer.WordTokenize(modelResponse.Text);
+            string[][] references = context.References.Select(reference => SimpleWordTokenizer.WordTokenize(reference).ToArray()).ToArray();
+            string[] hypothesis = SimpleWordTokenizer.WordTokenize(modelResponse.Text).ToArray();
             return BLEUAlgorithm.SentenceBLEU(references, hypothesis, BLEUAlgorithm.DefaultBLEUWeights, SmoothingFunction.Method4);
         });

         metric.Value = score;
         string durationText = $"{duration.TotalSeconds.ToString("F2", CultureInfo.InvariantCulture)} s";
         metric.AddOrUpdateMetadata(name: "evaluation-duration", value: durationText);
         metric.AddOrUpdateContext(context);
-        metric.Interpretation = NLPScoreInterpretation.Interpret(metric);
+        metric.Interpretation = metric.Interpret();

         return new ValueTask<EvaluationResult>(result);
     }
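
A minimal sketch of the tokenize-then-score pipeline this hunk settles on. It mirrors the diff lines above; the sentences are invented, and `SimpleWordTokenizer`, `BLEUAlgorithm` and `SmoothingFunction` are internal library types, so this is illustrative rather than a public-API example:

```csharp
using System.Linq;
using Microsoft.Extensions.AI.Evaluation.NLP.Common;

// Tokenize each reference into a string[] and collect them into a string[][],
// exactly as the evaluator now does internally.
string[][] references = new[] { "The cat sat on the mat.", "A cat was sitting on the mat." }
    .Select(reference => SimpleWordTokenizer.WordTokenize(reference).ToArray())
    .ToArray();

// Tokenize the hypothesis (the model response) the same way.
string[] hypothesis = SimpleWordTokenizer.WordTokenize("The cat sat on a mat.").ToArray();

// Score with the default 4-gram weights and Method4 smoothing, as in the diff.
double score = BLEUAlgorithm.SentenceBLEU(
    references, hypothesis, BLEUAlgorithm.DefaultBLEUWeights, SmoothingFunction.Method4);
```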

src/Libraries/Microsoft.Extensions.AI.Evaluation.NLP/BLEUEvaluatorContext.cs

Lines changed: 7 additions & 7 deletions
@@ -24,10 +24,10 @@ public sealed class BLEUEvaluatorContext : EvaluationContext
     /// Gets the unique <see cref="EvaluationContext.Name"/> that is used for
     /// <see cref="BLEUEvaluatorContext"/>.
     /// </summary>
-    public static string BLEUContextName => "BLEU Context";
+    public static string ReferencesContextName => "References (BLEU)";

     /// <summary>
-    /// Gets the reference responses against which the provided model response will be scored.
+    /// Gets the references against which the provided response will be scored.
     /// </summary>
     /// <remarks>
     /// The <see cref="BLEUEvaluator"/> measures the degree to which the response being evaluated is similar to

@@ -41,8 +41,8 @@ public sealed class BLEUEvaluatorContext : EvaluationContext
     /// <param name="references">
     /// The reference responses against which the response that is being evaluated is compared.
     /// </param>
-    public BLEUEvaluatorContext(params string[] references)
-        : this(references as IEnumerable<string>)
+    public BLEUEvaluatorContext(IEnumerable<string> references)
+        : this(references.ToArray())
     {
     }

@@ -52,11 +52,11 @@ public BLEUEvaluatorContext(params string[] references)
     /// <param name="references">
     /// The reference responses against which the response that is being evaluated is compared.
     /// </param>
-    public BLEUEvaluatorContext(IEnumerable<string> references)
+    public BLEUEvaluatorContext(params string[] references)
         : base(
-            name: BLEUContextName,
+            name: ReferencesContextName,
             contents: [.. references.Select(c => new TextContent(c))])
     {
-        References = [.. references];
+        References = references;
     }
 }
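
Both constructor shapes survive the swap above; only which overload delegates to the other changes. A quick sketch with invented reference strings:

```csharp
using System.Collections.Generic;

// The params overload now backs References with the array directly.
var fromParams = new BLEUEvaluatorContext(
    "The quick brown fox jumps over the lazy dog.",
    "A quick brown fox jumped over a lazy dog.");

// The IEnumerable<string> overload copies to an array and delegates to it.
IEnumerable<string> refs = new List<string> { "Reference one.", "Reference two." };
var fromEnumerable = new BLEUEvaluatorContext(refs);
```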

src/Libraries/Microsoft.Extensions.AI.Evaluation.NLP/Common/BLEUAlgorithm.cs

Lines changed: 16 additions & 13 deletions
@@ -16,7 +16,7 @@ namespace Microsoft.Extensions.AI.Evaluation.NLP.Common;
 /// </summary>
 internal static class BLEUAlgorithm
 {
-    internal static int ClosestRefLength(IEnumerable<IEnumerable<string>> references, int hypLength)
+    internal static int ClosestRefLength(string[][] references, int hypLength)
     {
         if (!references.Any())
         {

@@ -27,7 +27,7 @@ internal static int ClosestRefLength(IEnumerable<IEnumerable<string>> references
         int smallestDiff = int.MaxValue;
         foreach (var reference in references)
         {
-            int refLength = reference.Count();
+            int refLength = reference.Length;
             int diff = Math.Abs(refLength - hypLength);
             if (diff < smallestDiff ||
                 (diff == smallestDiff && refLength < closestRefLength))

@@ -55,27 +55,27 @@ internal static double BrevityPenalty(int closestRefLength, int hypLength)
         return Math.Exp(1 - ((double)closestRefLength / hypLength));
     }

-    internal static RationalNumber ModifiedPrecision(IEnumerable<IEnumerable<string>> references, IEnumerable<string> hypothesis, int n = 1)
+    internal static RationalNumber ModifiedPrecision(string[][] references, string[] hypothesis, int n = 1)
     {
         if (n <= 0)
         {
             Throw.ArgumentOutOfRangeException(nameof(n), $"`{nameof(n)}` must be greater than zero.");
         }

-        if (!references.Any() || !hypothesis.Any())
+        if (references.Length == 0 || hypothesis.Length == 0)
         {
             return RationalNumber.Zero;
         }

-        var hyp = hypothesis.CreateNGrams(n);
-        var hypCounts = new MatchCounter<NGram<string>>(hyp);
+        List<NGram<string>> hypGrams = hypothesis.CreateNGrams(n);
+        MatchCounter<NGram<string>> hypCounts = new(hypGrams);

         Dictionary<NGram<string>, int> maxCounts = [];

         foreach (var rf in references)
         {
-            IEnumerable<NGram<string>> refGrams = rf.CreateNGrams(n);
-            var refCounts = new MatchCounter<NGram<string>>(refGrams);
+            List<NGram<string>> refGrams = rf.CreateNGrams(n);
+            MatchCounter<NGram<string>> refCounts = new(refGrams);

             foreach (var ct in refCounts)
             {

@@ -123,25 +123,28 @@ internal static double[] EqualWeights(int n)
         }

         double[] weights = new double[n];
+#if NET8_0_OR_GREATER
+        Array.Fill(weights, 1.0 / n);
+#else
         for (int i = 0; i < n; i++)
         {
             weights[i] = 1.0 / n;
         }
-
+#endif
         return weights;
     }

     internal static readonly double[] DefaultBLEUWeights = EqualWeights(4);

-    internal static double SentenceBLEU(IEnumerable<IEnumerable<string>> references, IEnumerable<string> hypothesis,
+    internal static double SentenceBLEU(string[][] references, string[] hypothesis,
         double[]? weights = null, Func<RationalNumber[], int, double[]>? smoothingFunction = null)
     {
-        if (references == null || !references.Any())
+        if (references == null || references.Length == 0)
         {
             Throw.ArgumentNullException(nameof(references), $"'{nameof(references)}' cannot be null or empty.");
         }

-        if (hypothesis == null || !hypothesis.Any())
+        if (hypothesis == null || hypothesis.Length == 0)
         {
             Throw.ArgumentNullException(nameof(hypothesis), $"'{nameof(hypothesis)}' cannot be null or empty.");
         }

@@ -171,7 +174,7 @@ internal static double SentenceBLEU(IEnumerable<IEnumerable<string>> references,
             precisionValues[i] = prec;
         }

-        int hypLen = hypothesis.Count();
+        int hypLen = hypothesis.Length;
         int closestRefLength = ClosestRefLength(references, hypLen);
         double brevityPenalty = BrevityPenalty(closestRefLength, hypLen);
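
A worked check of two helpers touched in this file, using only the formulas visible in the hunks (the no-penalty branch of BrevityPenalty, for hypotheses at least as long as the reference, sits outside the hunk shown):

```csharp
// EqualWeights(4) fills the array with 1.0 / 4 on every target framework;
// the #if NET8_0_OR_GREATER path just does it with Array.Fill.
double[] weights = BLEUAlgorithm.EqualWeights(4); // [0.25, 0.25, 0.25, 0.25]

// The penalised branch computes exp(1 - ref/hyp). For closestRefLength = 10
// and hypLength = 8: exp(1 - 10/8) = exp(-0.25) ≈ 0.7788.
double penalty = BLEUAlgorithm.BrevityPenalty(closestRefLength: 10, hypLength: 8);
```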

src/Libraries/Microsoft.Extensions.AI.Evaluation.NLP/Common/F1Algorithm.cs

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Microsoft.Shared.Diagnostics;
+
+namespace Microsoft.Extensions.AI.Evaluation.NLP.Common;
+
+/// <summary>
+/// The F1 score for a response is the harmonic mean of precision and recall, computed over the words
+/// shared between the generated response and the reference response. Python implementation reference:
+/// https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py.
+/// </summary>
+internal static class F1Algorithm
+{
+    public static double CalculateF1Score(string[] groundTruth, string[] response)
+    {
+        if (groundTruth == null || groundTruth.Length == 0)
+        {
+            Throw.ArgumentNullException(nameof(groundTruth), $"'{nameof(groundTruth)}' cannot be null or empty.");
+        }
+
+        if (response == null || response.Length == 0)
+        {
+            Throw.ArgumentNullException(nameof(response), $"'{nameof(response)}' cannot be null or empty.");
+        }
+
+        MatchCounter<string> referenceTokens = new(groundTruth);
+        MatchCounter<string> predictionTokens = new(response);
+        MatchCounter<string> commonTokens = referenceTokens.Intersect(predictionTokens);
+        int numCommonTokens = commonTokens.Sum();
+
+        if (numCommonTokens == 0)
+        {
+            return 0.0; // F1 score is 0 if there are no common tokens
+        }
+        else
+        {
+            double precision = (double)numCommonTokens / response.Length;
+            double recall = (double)numCommonTokens / groundTruth.Length;
+            double f1 = (2.0 * precision * recall) / (precision + recall);
+            return f1;
+        }
+    }
+}
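
A worked example of the computation above, with invented tokens. The multiset intersection of [the, cat, sat] and [the, cat, ran, away] has two common tokens, so precision = 2/4, recall = 2/3, and F1 = 2 · (1/2) · (2/3) / (1/2 + 2/3) = 4/7 ≈ 0.571:

```csharp
string[] groundTruth = ["the", "cat", "sat"];
string[] response = ["the", "cat", "ran", "away"];

// precision = 2/4, recall = 2/3, harmonic mean = 4/7
double f1 = F1Algorithm.CalculateF1Score(groundTruth, response); // ≈ 0.571
```
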
src/Libraries/Microsoft.Extensions.AI.Evaluation.NLP/Common/GLEUAlgorithm.cs

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using Microsoft.Shared.Diagnostics;
+
+namespace Microsoft.Extensions.AI.Evaluation.NLP.Common;
+
+/// <summary>
+/// Google-BLEU (GLEU) algorithm implementation for evaluating the quality of a response.
+/// Python implementation reference: https://www.nltk.org/api/nltk.translate.gleu_score.html.
+/// </summary>
+internal static class GLEUAlgorithm
+{
+    internal static double SentenceGLEU(string[][] references, string[] hypothesis, int minN = 1, int maxN = 4)
+    {
+        if (references == null || references.Length == 0)
+        {
+            Throw.ArgumentNullException(nameof(references), $"'{nameof(references)}' cannot be null or empty.");
+        }
+
+        if (hypothesis == null || hypothesis.Length == 0)
+        {
+            Throw.ArgumentNullException(nameof(hypothesis), $"'{nameof(hypothesis)}' cannot be null or empty.");
+        }
+
+        MatchCounter<NGram<string>> hypNGrams = new(hypothesis.CreateAllNGrams(minN, maxN));
+        int truePosFalsePos = hypNGrams.Sum();
+
+        List<(int, int)> hypCounts = [];
+        foreach (var reference in references)
+        {
+            MatchCounter<NGram<string>> refNGrams = new(reference.CreateAllNGrams(minN, maxN));
+            int truePosFalseNeg = refNGrams.Sum();
+
+            MatchCounter<NGram<string>> overlapNGrams = hypNGrams.Intersect(refNGrams);
+            int truePos = overlapNGrams.Sum();
+
+            int nAll = Math.Max(truePosFalsePos, truePosFalseNeg);
+
+            if (nAll > 0)
+            {
+                hypCounts.Add((truePos, nAll));
+            }
+        }
+
+        int corpusNMatch = 0;
+        int corpusNAll = 0;
+
+        foreach (var (truePos, nAll) in hypCounts)
+        {
+            corpusNMatch += truePos;
+            corpusNAll += nAll;
+        }
+
+        if (corpusNAll == 0)
+        {
+            return 0.0;
+        }
+        else
+        {
+            return (double)corpusNMatch / corpusNAll;
+        }
+    }
+}
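
A worked example of SentenceGLEU with invented tokens. With minN = 1 and maxN = 2, the hypothesis [the, cat] yields three n-grams (the, cat, the·cat) and the reference [the, dog] also yields three; only the unigram "the" overlaps, so GLEU = 1 / max(3, 3) = 1/3:

```csharp
string[][] references = [["the", "dog"]];
string[] hypothesis = ["the", "cat"];

// 1 overlapping n-gram out of max(3, 3) total
double gleu = GLEUAlgorithm.SentenceGLEU(references, hypothesis, minN: 1, maxN: 2); // ≈ 0.333
```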

src/Libraries/Microsoft.Extensions.AI.Evaluation.NLP/Common/MatchCounter.cs

Lines changed: 20 additions & 1 deletion
@@ -53,7 +53,26 @@ public void AddRange(IEnumerable<T> items)
         }
     }

-    public string ToDebugString() => string.Concat(_counts.Select(v => $"{v.Key}: {v.Value}, "));
+    public MatchCounter<T> Intersect(MatchCounter<T> other)
+    {
+        _ = Throw.IfNull(other, nameof(other));
+        var intersection = new MatchCounter<T>();
+
+        (Dictionary<T, int> smaller, Dictionary<T, int> larger) =
+            _counts.Count < other._counts.Count ? (_counts, other._counts) : (other._counts, _counts);
+
+        foreach (var kvp in smaller)
+        {
+            if (larger.TryGetValue(kvp.Key, out int otherCount))
+            {
+                intersection._counts[kvp.Key] = Math.Min(kvp.Value, otherCount);
+            }
+        }
+
+        return intersection;
+    }
+
+    public string ToDebugString() => string.Join(",", _counts.Select(v => $"{v.Key}: {v.Value}"));

     public IEnumerator<KeyValuePair<T, int>> GetEnumerator() => _counts.GetEnumerator();
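
An illustrative use of the new Intersect method: counts are min'd per key, and keys absent from either side drop out. The IEnumerable<T> constructor and Sum() are the same members the F1 and GLEU code above relies on:

```csharp
var left = new MatchCounter<string>(["x", "x", "y"]);       // x:2, y:1
var right = new MatchCounter<string>(["x", "y", "y", "z"]); // x:1, y:2, z:1

MatchCounter<string> common = left.Intersect(right);        // x:1, y:1
int total = common.Sum();                                   // 1 + 1 = 2
```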
