Skip to content

Commit c0c1ddf

Browse files
canyon289ColCarroll
authored andcommitted
Parametrize shape and size tests (#3099)
1 parent 3e34034 commit c0c1ddf

File tree

1 file changed

+53
-53
lines changed

1 file changed

+53
-53
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 53 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -145,73 +145,73 @@ def sample_random_variable(random_variable, size):
145145
except AttributeError:
146146
return random_variable.distribution.random(size=size)
147147

148-
def test_scalar_parameter_shape(self):
148+
@pytest.mark.parametrize('size', [None, 5, (4, 5)], ids=str)
149+
def test_scalar_parameter_shape(self, size):
149150
rv = self.get_random_variable(None)
150-
for size in (None, 5, (4, 5)):
151-
if size is None:
152-
expected = 1,
153-
else:
154-
expected = np.atleast_1d(size).tolist()
155-
actual = np.atleast_1d(self.sample_random_variable(rv, size)).shape
156-
assert tuple(expected) == actual
151+
if size is None:
152+
expected = 1,
153+
else:
154+
expected = np.atleast_1d(size).tolist()
155+
actual = np.atleast_1d(self.sample_random_variable(rv, size)).shape
156+
assert tuple(expected) == actual
157157

158-
def test_scalar_shape(self):
158+
@pytest.mark.parametrize('size', [None, 5, (4, 5)], ids=str)
159+
def test_scalar_shape(self, size):
159160
shape = 10
160161
rv = self.get_random_variable(shape)
161-
for size in (None, 5, (4, 5)):
162-
if size is None:
163-
expected = []
164-
else:
165-
expected = np.atleast_1d(size).tolist()
166-
expected.append(shape)
167-
actual = np.atleast_1d(self.sample_random_variable(rv, size)).shape
168-
assert tuple(expected) == actual
169162

170-
def test_parameters_1d_shape(self):
163+
if size is None:
164+
expected = []
165+
else:
166+
expected = np.atleast_1d(size).tolist()
167+
expected.append(shape)
168+
actual = np.atleast_1d(self.sample_random_variable(rv, size)).shape
169+
assert tuple(expected) == actual
170+
171+
@pytest.mark.parametrize('size', [None, 5, (4, 5)], ids=str)
172+
def test_parameters_1d_shape(self, size):
171173
rv = self.get_random_variable(self.shape, with_vector_params=True)
172-
for size in (None, 5, (4, 5)):
173-
if size is None:
174-
expected = []
175-
else:
176-
expected = np.atleast_1d(size).tolist()
177-
expected.append(self.shape)
178-
actual = self.sample_random_variable(rv, size).shape
179-
assert tuple(expected) == actual
174+
if size is None:
175+
expected = []
176+
else:
177+
expected = np.atleast_1d(size).tolist()
178+
expected.append(self.shape)
179+
actual = self.sample_random_variable(rv, size).shape
180+
assert tuple(expected) == actual
180181

181-
def test_broadcast_shape(self):
182+
@pytest.mark.parametrize('size', [None, 5, (4, 5)], ids=str)
183+
def test_broadcast_shape(self, size):
182184
broadcast_shape = (2 * self.shape, self.shape)
183185
rv = self.get_random_variable(broadcast_shape, with_vector_params=True)
184-
for size in (None, 5, (4, 5)):
185-
if size is None:
186-
expected = []
187-
else:
188-
expected = np.atleast_1d(size).tolist()
189-
expected.extend(broadcast_shape)
190-
actual = np.atleast_1d(self.sample_random_variable(rv, size)).shape
191-
assert tuple(expected) == actual
186+
if size is None:
187+
expected = []
188+
else:
189+
expected = np.atleast_1d(size).tolist()
190+
expected.extend(broadcast_shape)
191+
actual = np.atleast_1d(self.sample_random_variable(rv, size)).shape
192+
assert tuple(expected) == actual
192193

193-
def test_different_shapes_and_sample_sizes(self):
194-
shapes = [(), (1,), (1, 1), (1, 2), (10, 10, 1), (10, 10, 2)]
194+
@pytest.mark.parametrize('shape', [(), (1,), (1, 1), (1, 2), (10, 10, 1), (10, 10, 2)], ids=str)
195+
def test_different_shapes_and_sample_sizes(self, shape):
195196
prefix = self.distribution.__name__
196197
expected = []
197198
actual = []
198-
for shape in shapes:
199-
rv = self.get_random_variable(shape, name='%s_%s' % (prefix, shape))
200-
for size in (None, 1, 5, (4, 5)):
201-
if size is None:
199+
rv = self.get_random_variable(shape, name='%s_%s' % (prefix, shape))
200+
for size in (None, 1, 5, (4, 5)):
201+
if size is None:
202+
s = []
203+
else:
204+
try:
205+
s = list(size)
206+
except TypeError:
207+
s = [size]
208+
if s == [1]:
202209
s = []
203-
else:
204-
try:
205-
s = list(size)
206-
except TypeError:
207-
s = [size]
208-
if s == [1]:
209-
s = []
210-
if shape not in ((), (1,)):
211-
s.extend(shape)
212-
e = tuple(s)
213-
a = self.sample_random_variable(rv, size).shape
214-
assert e == a
210+
if shape not in ((), (1,)):
211+
s.extend(shape)
212+
e = tuple(s)
213+
a = self.sample_random_variable(rv, size).shape
214+
assert e == a
215215

216216

217217
class TestNormal(BaseTestCases.BaseTestCase):

0 commit comments

Comments
 (0)