Skip to content

Commit f645fee

Browse files
committed
fix: normalize probabilities for entropy
1 parent 7d1fd8e commit f645fee

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

src/Models/Results/PredictionModelResult.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,13 +1243,42 @@ private static Vector<T> ComputePerSampleEntropy(Vector<T> probabilitiesFlat, in
12431243
var numOps = MathHelper.GetNumericOperations<T>();
12441244
var entropy = new Vector<T>(batch);
12451245
var eps = numOps.FromDouble(1e-12);
1246+
var sumTolerance = numOps.FromDouble(1e-6);
12461247

12471248
for (int b = 0; b < batch; b++)
12481249
{
1250+
var sumP = numOps.Zero;
1251+
for (int c = 0; c < classes; c++)
1252+
{
1253+
var p = probabilitiesFlat[b * classes + c];
1254+
if (numOps.LessThan(p, numOps.Zero))
1255+
{
1256+
p = numOps.Zero;
1257+
}
1258+
sumP = numOps.Add(sumP, p);
1259+
}
1260+
1261+
if (numOps.LessThan(sumP, eps))
1262+
{
1263+
entropy[b] = numOps.Zero;
1264+
continue;
1265+
}
1266+
1267+
var shouldNormalize = numOps.GreaterThan(numOps.Abs(numOps.Subtract(sumP, numOps.One)), sumTolerance);
1268+
var denom = shouldNormalize ? sumP : numOps.One;
1269+
12491270
var h = numOps.Zero;
12501271
for (int c = 0; c < classes; c++)
12511272
{
12521273
var p = probabilitiesFlat[b * classes + c];
1274+
if (numOps.LessThan(p, numOps.Zero))
1275+
{
1276+
p = numOps.Zero;
1277+
}
1278+
if (shouldNormalize)
1279+
{
1280+
p = numOps.Divide(p, denom);
1281+
}
12531282
if (numOps.LessThan(p, eps))
12541283
{
12551284
p = eps;

0 commit comments

Comments
 (0)