Skip to content

Commit b54851d

Browse files
committed
Fixed GetRandomWeightedEntry throwing an exception when the random value is very close to 1
1 parent 119a70e commit b54851d

File tree

4 files changed

+57
-4
lines changed

4 files changed

+57
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ MLEM uses [semantic versioning](https://semver.org/). Potentially breaking chang
66
### MLEM
77
Fixes
88
- Fixed sprite animation groups not advancing to their first frame immediately
9+
- Fixed GetRandomWeightedEntry throwing an exception when the random value is very close to 1
910

1011
### MLEM.Ui
1112
Additions

MLEM/MLEM.csproj

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,8 @@
3030
<None Include="../Media/Logo.png" Pack="true" PackagePath="" />
3131
<None Include="../README.md" Pack="true" PackagePath="" />
3232
</ItemGroup>
33+
34+
<ItemGroup>
35+
<InternalsVisibleTo Include="Tests" />
36+
</ItemGroup>
3337
</Project>

MLEM/Maths/RandomExtensions.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ namespace MLEM.Maths {
88
/// </summary>
99
public static class RandomExtensions {
1010

11+
private const string IndexOutOfRangeString = "Reached end of the collection in GetRandomWeightedEntry. For each entry, the passed weight function should return the same value on each invocation.";
12+
1113
/// <summary>
1214
/// Gets a random entry from the given collection with uniform chance.
1315
/// </summary>
@@ -80,21 +82,21 @@ internal static T GetRandomWeightedEntry<T>(ICollection<T> entries, Func<T, int>
8082
var currWeight = 0;
8183
foreach (var entry in entries) {
8284
currWeight += weightFunc(entry);
83-
if (currWeight > goalWeight)
85+
if (currWeight >= goalWeight)
8486
return entry;
8587
}
86-
throw new IndexOutOfRangeException();
88+
throw new IndexOutOfRangeException(RandomExtensions.IndexOutOfRangeString);
8789
}
8890

8991
internal static T GetRandomWeightedEntry<T>(ICollection<T> entries, Func<T, float> weightFunc, float randomValue) {
9092
var goalWeight = randomValue * entries.Sum(weightFunc);
9193
var currWeight = 0F;
9294
foreach (var entry in entries) {
9395
currWeight += weightFunc(entry);
94-
if (currWeight > goalWeight)
96+
if (currWeight >= goalWeight)
9597
return entry;
9698
}
97-
throw new IndexOutOfRangeException();
99+
throw new IndexOutOfRangeException(RandomExtensions.IndexOutOfRangeString);
98100
}
99101

100102
}

Tests/CollectionTests.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using MLEM.Maths;
15
using MLEM.Misc;
26
using NUnit.Framework;
37

@@ -28,4 +32,46 @@ public void TestCombinations() {
2832
Assert.AreEqual(things.IndexCombinations(), indices);
2933
}
3034

35+
[Test]
36+
public void TestRandomWeightedEntryEqual([Values(0.1F, 1, 20, 23.5F, 99, 10000000)] float equalWeight, [Values(true, false)] bool integer) {
37+
var entries = new[] {"A", "B", "C", "D", "E"};
38+
var random = new Random(390453);
39+
var matches = new Dictionary<string, int>();
40+
for (var i = 0; i < 100000; i++) {
41+
var entry = integer ? random.GetRandomWeightedEntry(entries, _ => equalWeight.Ceil()) : random.GetRandomWeightedEntry(entries, _ => equalWeight);
42+
matches[entry] = matches.GetValueOrDefault(entry) + 1;
43+
}
44+
for (var i = 0; i < entries.Length; i++)
45+
Assert.AreEqual(100000 / entries.Length, matches[entries[i]], 1000);
46+
}
47+
48+
[Test]
49+
public void TestRandomWeightedEntryVaried([Values(1, 37.283923F, 99)] float weightMult, [Values(true, false)] bool integer) {
50+
var weights = new[] {
51+
("A", 1),
52+
("B", 2),
53+
("C", 3),
54+
("D", integer ? 14 : 14.389238F),
55+
("E", 20)
56+
};
57+
var random = new Random(234598223);
58+
var matches = new Dictionary<string, int>();
59+
for (var i = 0; i < 1000000; i++) {
60+
var entry = (integer ? random.GetRandomWeightedEntry(weights, e => (e.Item2 * weightMult).Ceil()) : random.GetRandomWeightedEntry(weights, e => e.Item2 * weightMult)).Item1;
61+
matches[entry] = matches.GetValueOrDefault(entry) + 1;
62+
}
63+
for (var i = 0; i < weights.Length; i++) {
64+
var expected = 1000000F / weights.Select(w => w.Item2).Sum() * weights[i].Item2;
65+
Assert.AreEqual(expected, matches[weights[i].Item1], 1000);
66+
}
67+
}
68+
69+
[Test]
70+
public void TestRandomWeightedEntryFixedValues() {
71+
Assert.AreEqual(RandomExtensions.GetRandomWeightedEntry([1, 2, 3], _ => 1, 0), 1);
72+
Assert.AreEqual(RandomExtensions.GetRandomWeightedEntry([1, 2, 3], _ => 1, 0.5F), 2);
73+
Assert.AreEqual(RandomExtensions.GetRandomWeightedEntry([1, 2, 3], _ => 1, 0.99999999999999989F), 3);
74+
Assert.AreEqual(RandomExtensions.GetRandomWeightedEntry([1, 2, 3], _ => 1, 1), 3);
75+
}
76+
3177
}

0 commit comments

Comments
 (0)