Skip to content

Commit 91aca18

Browse files
author
Julian Blank
committed
Added test cases for sampling
1 parent 7ba8cd4 commit 91aca18

File tree

3 files changed

+39
-15
lines changed

3 files changed

+39
-15
lines changed

pysampling/algorithms/halton.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,23 @@
44
from pysampling.util import calc_primes_until
55

66

7+
def halton_sequence_by_index(i, b):
8+
f = 1.0
9+
x = 0.0
10+
while i > 0:
11+
f /= b
12+
x += f * (i % b)
13+
i = np.floor(i / b)
14+
return x
15+
16+
17+
def halton_sequence(n, b):
18+
return np.array([halton_sequence_by_index(i, b) for i in range(n)])
19+
20+
721
class HaltonSampling(Sampling):
822

923
def _sample(self, n_points, n_dim):
1024
bases = calc_primes_until(500)[:n_dim]
11-
X = np.column_stack([self.halton_sequence(n_points, b) for b in bases])
25+
X = np.column_stack([halton_sequence(n_points, b) for b in bases])
1226
return X
13-
14-
def halton_sequence(self, n, b):
15-
return np.array([self.halton_sequence_by_index(i, b) for i in range(n)])
16-
17-
def halton_sequence_by_index(self, i, b):
18-
f = 1.0
19-
x = 0.0
20-
while i > 0:
21-
f /= b
22-
x += f * (i % b)
23-
i = np.floor(i / b)
24-
return x

pysampling/algorithms/sobol.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ class SobolSampling(Sampling):
1111
def __init__(self,
1212
n_skip=None,
1313
n_leap=0,
14-
setup="joekuo"):
14+
setup="joekuo",
15+
**kwargs):
1516

16-
super().__init__()
17+
super().__init__(**kwargs)
1718

1819
if setup == "matlab":
1920
if n_skip is None:

tests/test_sampling.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import unittest
2+
3+
from pysampling.sample import sample
4+
5+
6+
class SamplingTest(unittest.TestCase):
7+
8+
def test_random(self):
9+
sample("random", 50, 2, seed=1)
10+
11+
def test_lhs(self):
12+
sample("lhs", 50, 2, seed=1)
13+
14+
def test_sobol(self):
15+
sample("sobol", 50, 2, seed=1)
16+
17+
def test_halton(self):
18+
sample("halton", 50, 2, seed=1)
19+
20+
if __name__ == '__main__':
21+
unittest.main()

0 commit comments

Comments
 (0)