Skip to content

Commit 76ec06a

Browse files
CrystalWindSnakeabeham
authored andcommitted
Added 2 RandChoice methods in Simulation class (#27)
* Add 2 RandChoice methods in Simulation class * Add unit tests
1 parent 685d224 commit 76ec06a

File tree

2 files changed

+100
-0
lines changed

2 files changed

+100
-0
lines changed

src/SimSharp/Core/Environment.cs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using System.Collections.Generic;
1010
using System.Diagnostics;
1111
using System.IO;
12+
using System.Linq;
1213
using System.Threading;
1314
using System.Threading.Tasks;
1415

@@ -620,6 +621,69 @@ public TimeSpan RandWeibull(IRandom random, TimeSpan alpha, TimeSpan beta) {
620621
public TimeSpan RandWeibull(TimeSpan alpha, TimeSpan beta) {
621622
return RandWeibull(Random, alpha, beta);
622623
}
624+
625+
626+
/// <summary>
627+
/// Generates a random sample from a given source
628+
/// </summary>
629+
/// <typeparam name="T">The type of the element in parameter source</typeparam>
630+
/// <exception cref="ArgumentException">
631+
/// Thrown when <paramref name="source"/> and <paramref name="weights"/> have different size.
632+
/// or when <paramref name="weights"/> contains an invalid or negative value.
633+
/// or when <paramref name="weights"/> sum equals zero or an invalid value.
634+
/// </exception>
635+
/// <param name="random">The random number generator to use.</param>
636+
/// <param name="source">a random sample is generated from its elements.</param>
637+
/// <param name="weights">The weight associated with each entry in source.</param>
638+
/// <returns>The generated random samples</returns>
639+
public T RandChoice<T>(IRandom random, IList<T> source, IList<double> weights) {
640+
if (source.Count != weights.Count) {
641+
throw new ArgumentException("source and weights must have same size");
642+
}
643+
644+
double totalW = 0;
645+
foreach (var w in weights) {
646+
if (w < 0) {
647+
throw new ArgumentException("weight values must be non-negative", nameof(weights));
648+
}
649+
totalW += w;
650+
}
651+
652+
if (double.IsNaN(totalW) || double.IsInfinity(totalW))
653+
throw new ArgumentException("Not a valid weight", nameof(weights));
654+
if (totalW == 0)
655+
throw new ArgumentException("total weight must be greater than 0", nameof(weights));
656+
657+
var rnd = random.NextDouble();
658+
double aggWeight = 0;
659+
int idx = 0;
660+
foreach (var w in weights) {
661+
if (w > 0) {
662+
aggWeight += (w / totalW);
663+
if (rnd <= aggWeight) {
664+
break;
665+
}
666+
}
667+
idx++;
668+
}
669+
return source[idx];
670+
}
671+
/// <summary>
672+
/// Generates a random sample from a given source
673+
/// </summary>
674+
/// <typeparam name="T">The type of the element in parameter source</typeparam>
675+
/// <exception cref="ArgumentException">
676+
/// Thrown when <paramref name="source"/> and <paramref name="weights"/> have different size.
677+
/// or when <paramref name="weights"/> contains an invalid or negative value.
678+
/// or when <paramref name="weights"/> sum equals zero
679+
/// </exception>
680+
/// <param name="source">a random sample is generated from its elements.</param>
681+
/// <param name="weights">The weight associated with each entry in source.</param>
682+
/// <returns>The generated random samples</returns>
683+
public T RandChoice<T>(IList<T> source, IList<double> weights) {
684+
return RandChoice(Random, source, weights);
685+
}
686+
623687
#endregion
624688

625689
#region Random timeouts

src/Tests/RandomTest.cs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,41 @@ public void PcgReproducibilityTest() {
1414
var seq2 = Enumerable.Range(0, 1000).Select(x => pcg.Next()).ToArray();
1515
Assert.Equal(seq, seq2);
1616
}
17+
18+
[Theory]
19+
[InlineData(new double[] { double.NaN, 0.3 })]
20+
[InlineData(new double[] { 1.0 / 0.0, 0.3 })]
21+
[InlineData(new double[] { 0, 0 })]
22+
[InlineData(new double[] { -1, 10 })]
23+
[InlineData(new double[] { double.MaxValue, double.MaxValue })]
24+
public void RandChoiceTestArgumentException(double[] weights) {
25+
var env = new Simulation(15);
26+
Assert.Throws<System.ArgumentException>(
27+
() => env.RandChoice(new[] { "a", "b" }, weights));
28+
}
29+
30+
[Fact]
31+
public void RandChoiceTestContainZeroWeight() {
32+
var env = new Simulation(15);
33+
var source = new[] { "a", "b", "c" };
34+
var res1 = Enumerable.Range(1, 100)
35+
.Select(_ => env.RandChoice(source, new[] { 0.7, 0.2, 0 }));
36+
Assert.DoesNotContain("c", res1);
37+
}
38+
39+
[Fact]
40+
public void RandChoiceTestTotalWeightMoreThanOne() {
41+
var source = new[] { "a", "b", "c" };
42+
43+
var env1 = new Simulation(15);
44+
var res1 = Enumerable.Range(1, 100)
45+
.Select(_ => env1.RandChoice(source, new[] { 0.5, 0.3, 0.2 }));
46+
47+
var env2 = new Simulation(15); new Simulation(15);
48+
var res2 = Enumerable.Range(1, 100)
49+
.Select(_ => env2.RandChoice(source, new[] { 5d, 3, 2 }));
50+
51+
Assert.Equal(res1, res2);
52+
}
1753
}
1854
}

0 commit comments

Comments
 (0)