Skip to content

Commit 8562e0c

Browse files
committed
Merge pull request #402 from enthought/bugfix/389_bobs-version
Add a ContextTestCase to testing.py
2 parents dcda63a + 917a0e4 commit 8562e0c

File tree

10 files changed

+264
-237
lines changed

10 files changed

+264
-237
lines changed

distarray/dist/distarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self, distribution, dtype=float):
4242
# FIXME: code duplication with context.py.
4343
ctx = distribution.context
4444
# FIXME: this is bad...
45-
comm_name = ctx.comm
45+
comm_name = distribution.comm
4646
# FIXME: and this is bad...
4747
da_key = ctx._generate_key()
4848
ddpr = distribution.get_dim_data_per_rank()

distarray/dist/maps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ def __init__(self, context, global_dim_data, targets=None):
580580
self.grid_shape = tuple(m.grid_size for m in self.maps)
581581

582582
self.grid_shape = normalize_grid_shape(self.grid_shape, self.ndim,
583-
self.dist, len(context.targets))
583+
self.dist, len(self.targets))
584584

585585
nelts = reduce(operator.mul, self.grid_shape, 1)
586586
self.rank_from_coords = np.arange(nelts).reshape(self.grid_shape)

distarray/dist/tests/test_context.py

Lines changed: 35 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,22 @@
1717

1818
import numpy
1919

20+
from distarray.testing import ContextTestCase, check_targets
2021
from distarray.dist.context import Context
2122
from distarray.dist.maps import Distribution
2223
from distarray.dist.ipython_utils import IPythonClient
2324
from distarray.local import LocalArray
2425

2526

26-
class TestContext(unittest.TestCase):
27+
class TestContext(ContextTestCase):
2728
"""Test Context methods"""
2829

2930
@classmethod
3031
def setUpClass(cls):
31-
cls.context = Context()
32+
super(TestContext, cls).setUpClass()
3233
cls.ndarr = numpy.arange(16).reshape(4, 4)
3334
cls.darr = cls.context.fromndarray(cls.ndarr)
3435

35-
@classmethod
36-
def tearDownClass(cls):
37-
"""Close the client connections"""
38-
cls.context.close()
39-
4036
def test_get_localarrays(self):
4137
las = self.darr.get_localarrays()
4238
self.assertIsInstance(las[0], LocalArray)
@@ -49,68 +45,61 @@ def test_get_ndarrays(self):
4945
class TestContextCreation(unittest.TestCase):
5046
"""Test Context Creation"""
5147

48+
@classmethod
49+
def setUpClass(cls):
50+
cls.client = IPythonClient()
51+
52+
@classmethod
53+
def tearDownClass(cls):
54+
cls.client.close()
55+
5256
def test_create_Context(self):
5357
"""Can we create a plain vanilla context?"""
54-
client = IPythonClient()
55-
dac = Context(client)
56-
self.assertIs(dac.client, client)
57-
dac.close()
58-
client.close()
58+
dac = Context(self.client)
59+
self.assertIs(dac.client, self.client)
5960

6061
def test_create_Context_with_targets(self):
6162
"""Can we create a context with a subset of engines?"""
62-
client = IPythonClient()
63-
dac = Context(client, targets=[0, 1])
64-
self.assertIs(dac.client, client)
63+
check_targets(required=2, available=len(self.client))
64+
dac = Context(self.client, targets=[0, 1])
65+
self.assertIs(dac.client, self.client)
6566
dac.close()
66-
client.close()
6767

6868
def test_create_Context_with_targets_ranks(self):
6969
"""Check that the target <=> rank mapping is consistent."""
70-
client = IPythonClient()
70+
check_targets(required=4, available=len(self.client))
7171
targets = [3, 2]
72-
dac = Context(client, targets=targets)
72+
dac = Context(self.client, targets=targets)
7373
self.assertEqual(set(dac.targets), set(targets))
7474
dac.close()
75-
client.close()
7675

7776
def test_context_target_reordering(self):
7877
"""Are contexts' targets reordered in a consistent way?"""
79-
client = IPythonClient()
80-
orig_targets = client.ids
78+
orig_targets = self.client.ids
8179
targets1 = orig_targets[:]
8280
targets2 = orig_targets[:]
8381
shuffle(targets1)
8482
shuffle(targets2)
85-
ctx1 = Context(client, targets=targets1)
86-
ctx2 = Context(client, targets=targets2)
83+
ctx1 = Context(self.client, targets=targets1)
84+
ctx2 = Context(self.client, targets=targets2)
8785
self.assertEqual(ctx1.targets, ctx2.targets)
8886
ctx1.close()
8987
ctx2.close()
90-
client.close()
9188

9289
def test_create_delete_key(self):
9390
""" Check that a key can be created and then destroyed. """
94-
client = IPythonClient()
95-
dac = Context(client)
91+
dac = Context(self.client)
9692
# Create and push a key/value.
9793
key, value = dac._generate_key(), 'test'
9894
dac._push({key: value}, targets=dac.targets)
9995
# Delete the key.
10096
dac.delete_key(key)
10197
dac.close()
102-
client.close()
103-
10498

105-
class TestPrimeCluster(unittest.TestCase):
10699

107-
@classmethod
108-
def setUpClass(cls):
109-
cls.context = Context(targets=range(3))
100+
class TestPrimeCluster(ContextTestCase):
110101

111-
@classmethod
112-
def tearDownClass(cls):
113-
cls.context.close()
102+
ntargets = 3
114103

115104
def test_1D(self):
116105
d = Distribution.from_shape(self.context, (3,))
@@ -139,12 +128,9 @@ def test_3D(self):
139128
self.assertEqual(c.grid_shape, (1, 1, 3))
140129

141130

142-
class TestApply(unittest.TestCase):
131+
class TestApply(ContextTestCase):
143132

144-
@classmethod
145-
def setUpClass(cls):
146-
cls.context = Context()
147-
cls.num_targets = len(cls.context.targets)
133+
ntargets = 'any'
148134

149135
def test_apply_no_args(self):
150136

@@ -153,7 +139,7 @@ def foo():
153139

154140
val = self.context.apply(foo)
155141

156-
self.assertEqual(val, [42] * self.num_targets)
142+
self.assertEqual(val, [42] * self.ntargets)
157143

158144
def test_apply_pos_args(self):
159145

@@ -162,19 +148,19 @@ def foo(a, b, c):
162148

163149
# push all arguments
164150
val = self.context.apply(foo, (1, 2, 3))
165-
self.assertEqual(val, [6] * self.num_targets)
151+
self.assertEqual(val, [6] * self.ntargets)
166152

167153
# some local, some pushed
168154
local_thing = self.context._key_and_push(2)[0]
169155
val = self.context.apply(foo, (1, local_thing, 3))
170156

171-
self.assertEqual(val, [6] * self.num_targets)
157+
self.assertEqual(val, [6] * self.ntargets)
172158

173159
# all pushed
174160
local_args = self.context._key_and_push(1, 2, 3)
175161
val = self.context.apply(foo, local_args)
176162

177-
self.assertEqual(val, [6] * self.num_targets)
163+
self.assertEqual(val, [6] * self.ntargets)
178164

179165
def test_apply_kwargs(self):
180166

@@ -186,25 +172,25 @@ def foo(a, b, c=None, d=None):
186172
# empty kwargs
187173
val = self.context.apply(foo, (1, 2))
188174

189-
self.assertEqual(val, [0] * self.num_targets)
175+
self.assertEqual(val, [0] * self.ntargets)
190176

191177
# some empty
192178
val = self.context.apply(foo, (1, 2), {'d': 3})
193179

194-
self.assertEqual(val, [5] * self.num_targets)
180+
self.assertEqual(val, [5] * self.ntargets)
195181

196182
# all kwargs
197183
val = self.context.apply(foo, (1, 2), {'c': 2, 'd': 3})
198184

199-
self.assertEqual(val, [8] * self.num_targets)
185+
self.assertEqual(val, [8] * self.ntargets)
200186

201187
# now with local values
202188
local_a = self.context._key_and_push(1)[0]
203189
local_c = self.context._key_and_push(3)[0]
204190

205191
val = self.context.apply(foo, (local_a, 2), {'c': local_c, 'd': 3})
206192

207-
self.assertEqual(val, [9] * self.num_targets)
193+
self.assertEqual(val, [9] * self.ntargets)
208194

209195
def test_apply_proxyize(self):
210196

@@ -229,7 +215,7 @@ def bar(obj):
229215
return obj + 10
230216
val = self.context.apply(bar, (name,))
231217

232-
self.assertEqual(val, [20] * self.num_targets)
218+
self.assertEqual(val, [20] * self.ntargets)
233219

234220
def test_apply_proxyize_sync(self):
235221

distarray/dist/tests/test_decorators.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy
1818
from numpy.testing import assert_array_equal
1919

20+
from distarray.testing import ContextTestCase, check_targets
2021
from distarray.dist.context import Context
2122
from distarray.dist.maps import Distribution
2223
from distarray.dist.decorators import DecoratorBase, local, vectorize
@@ -78,7 +79,9 @@ def dummy_func(*args, **kwargs):
7879
self.assertTrue(db.key in kw_keys2)
7980

8081

81-
class TestLocalDecorator(TestCase):
82+
class TestLocalDecorator(ContextTestCase):
83+
84+
ntargets = 'any'
8285

8386
# Functions for @local decorator tests. These are here so we can
8487
# guarantee they are pushed to the engines before we try to use them.
@@ -143,15 +146,11 @@ def parameterless():
143146

144147
@classmethod
145148
def setUpClass(cls):
146-
cls.context = Context()
149+
super(TestLocalDecorator, cls).setUpClass()
147150
distribution = Distribution.from_shape(cls.context, (5, 5))
148151
cls.da = cls.context.empty(distribution)
149152
cls.da.fill(2 * numpy.pi)
150153

151-
@classmethod
152-
def tearDownClass(cls):
153-
cls.context.close()
154-
155154
def test_local(self):
156155
"""Test the @local decorator"""
157156
context = Context()
@@ -178,6 +177,8 @@ def fill_da(da):
178177
assert_array_equal(da.toarray(), a)
179178

180179
def test_different_contexts(self):
180+
check_targets(required=4, available=len(self.context.targets))
181+
181182
ctx1 = Context(targets=range(4))
182183
ctx2 = Context(targets=range(3))
183184
distribution1 = Distribution.from_shape(ctx1, (10,))
@@ -200,6 +201,8 @@ def test_local_add50(self):
200201

201202
def test_local_sum(self):
202203
dd = self.local_sum(self.da)
204+
if self.ntargets == 1:
205+
dd = [dd]
203206
lshapes = self.da.get_localshapes()
204207
expected = []
205208
for lshape in lshapes:

0 commit comments

Comments
 (0)