1717
1818import numpy
1919
20+ from distarray .testing import ContextTestCase , check_targets
2021from distarray .dist .context import Context
2122from distarray .dist .maps import Distribution
2223from distarray .dist .ipython_utils import IPythonClient
2324from distarray .local import LocalArray
2425
2526
26- class TestContext (unittest . TestCase ):
27+ class TestContext (ContextTestCase ):
2728 """Test Context methods"""
2829
2930 @classmethod
3031 def setUpClass (cls ):
31- cls . context = Context ()
32+ super ( TestContext , cls ). setUpClass ()
3233 cls .ndarr = numpy .arange (16 ).reshape (4 , 4 )
3334 cls .darr = cls .context .fromndarray (cls .ndarr )
3435
35- @classmethod
36- def tearDownClass (cls ):
37- """Close the client connections"""
38- cls .context .close ()
39-
4036 def test_get_localarrays (self ):
4137 las = self .darr .get_localarrays ()
4238 self .assertIsInstance (las [0 ], LocalArray )
@@ -49,68 +45,61 @@ def test_get_ndarrays(self):
4945class TestContextCreation (unittest .TestCase ):
5046 """Test Context Creation"""
5147
48+ @classmethod
49+ def setUpClass (cls ):
50+ cls .client = IPythonClient ()
51+
52+ @classmethod
53+ def tearDownClass (cls ):
54+ cls .client .close ()
55+
5256 def test_create_Context (self ):
5357 """Can we create a plain vanilla context?"""
54- client = IPythonClient ()
55- dac = Context (client )
56- self .assertIs (dac .client , client )
57- dac .close ()
58- client .close ()
58+ dac = Context (self .client )
59+ self .assertIs (dac .client , self .client )
5960
6061 def test_create_Context_with_targets (self ):
6162 """Can we create a context with a subset of engines?"""
62- client = IPythonClient ( )
63- dac = Context (client , targets = [0 , 1 ])
64- self .assertIs (dac .client , client )
63+ check_targets ( required = 2 , available = len ( self . client ) )
64+ dac = Context (self . client , targets = [0 , 1 ])
65+ self .assertIs (dac .client , self . client )
6566 dac .close ()
66- client .close ()
6767
6868 def test_create_Context_with_targets_ranks (self ):
6969 """Check that the target <=> rank mapping is consistent."""
70- client = IPythonClient ( )
70+ check_targets ( required = 4 , available = len ( self . client ) )
7171 targets = [3 , 2 ]
72- dac = Context (client , targets = targets )
72+ dac = Context (self . client , targets = targets )
7373 self .assertEqual (set (dac .targets ), set (targets ))
7474 dac .close ()
75- client .close ()
7675
7776 def test_context_target_reordering (self ):
7877 """Are contexts' targets reordered in a consistent way?"""
79- client = IPythonClient ()
80- orig_targets = client .ids
78+ orig_targets = self .client .ids
8179 targets1 = orig_targets [:]
8280 targets2 = orig_targets [:]
8381 shuffle (targets1 )
8482 shuffle (targets2 )
85- ctx1 = Context (client , targets = targets1 )
86- ctx2 = Context (client , targets = targets2 )
83+ ctx1 = Context (self . client , targets = targets1 )
84+ ctx2 = Context (self . client , targets = targets2 )
8785 self .assertEqual (ctx1 .targets , ctx2 .targets )
8886 ctx1 .close ()
8987 ctx2 .close ()
90- client .close ()
9188
9289 def test_create_delete_key (self ):
9390 """ Check that a key can be created and then destroyed. """
94- client = IPythonClient ()
95- dac = Context (client )
91+ dac = Context (self .client )
9692 # Create and push a key/value.
9793 key , value = dac ._generate_key (), 'test'
9894 dac ._push ({key : value }, targets = dac .targets )
9995 # Delete the key.
10096 dac .delete_key (key )
10197 dac .close ()
102- client .close ()
103-
10498
105- class TestPrimeCluster (unittest .TestCase ):
10699
107- @classmethod
108- def setUpClass (cls ):
109- cls .context = Context (targets = range (3 ))
100+ class TestPrimeCluster (ContextTestCase ):
110101
111- @classmethod
112- def tearDownClass (cls ):
113- cls .context .close ()
102+ ntargets = 3
114103
115104 def test_1D (self ):
116105 d = Distribution .from_shape (self .context , (3 ,))
@@ -139,12 +128,9 @@ def test_3D(self):
139128 self .assertEqual (c .grid_shape , (1 , 1 , 3 ))
140129
141130
142- class TestApply (unittest . TestCase ):
131+ class TestApply (ContextTestCase ):
143132
144- @classmethod
145- def setUpClass (cls ):
146- cls .context = Context ()
147- cls .num_targets = len (cls .context .targets )
133+ ntargets = 'any'
148134
149135 def test_apply_no_args (self ):
150136
@@ -153,7 +139,7 @@ def foo():
153139
154140 val = self .context .apply (foo )
155141
156- self .assertEqual (val , [42 ] * self .num_targets )
142+ self .assertEqual (val , [42 ] * self .ntargets )
157143
158144 def test_apply_pos_args (self ):
159145
@@ -162,19 +148,19 @@ def foo(a, b, c):
162148
163149 # push all arguments
164150 val = self .context .apply (foo , (1 , 2 , 3 ))
165- self .assertEqual (val , [6 ] * self .num_targets )
151+ self .assertEqual (val , [6 ] * self .ntargets )
166152
167153 # some local, some pushed
168154 local_thing = self .context ._key_and_push (2 )[0 ]
169155 val = self .context .apply (foo , (1 , local_thing , 3 ))
170156
171- self .assertEqual (val , [6 ] * self .num_targets )
157+ self .assertEqual (val , [6 ] * self .ntargets )
172158
173159 # all pushed
174160 local_args = self .context ._key_and_push (1 , 2 , 3 )
175161 val = self .context .apply (foo , local_args )
176162
177- self .assertEqual (val , [6 ] * self .num_targets )
163+ self .assertEqual (val , [6 ] * self .ntargets )
178164
179165 def test_apply_kwargs (self ):
180166
@@ -186,25 +172,25 @@ def foo(a, b, c=None, d=None):
186172 # empty kwargs
187173 val = self .context .apply (foo , (1 , 2 ))
188174
189- self .assertEqual (val , [0 ] * self .num_targets )
175+ self .assertEqual (val , [0 ] * self .ntargets )
190176
191177 # some empty
192178 val = self .context .apply (foo , (1 , 2 ), {'d' : 3 })
193179
194- self .assertEqual (val , [5 ] * self .num_targets )
180+ self .assertEqual (val , [5 ] * self .ntargets )
195181
196182 # all kwargs
197183 val = self .context .apply (foo , (1 , 2 ), {'c' : 2 , 'd' : 3 })
198184
199- self .assertEqual (val , [8 ] * self .num_targets )
185+ self .assertEqual (val , [8 ] * self .ntargets )
200186
201187 # now with local values
202188 local_a = self .context ._key_and_push (1 )[0 ]
203189 local_c = self .context ._key_and_push (3 )[0 ]
204190
205191 val = self .context .apply (foo , (local_a , 2 ), {'c' : local_c , 'd' : 3 })
206192
207- self .assertEqual (val , [9 ] * self .num_targets )
193+ self .assertEqual (val , [9 ] * self .ntargets )
208194
209195 def test_apply_proxyize (self ):
210196
@@ -229,7 +215,7 @@ def bar(obj):
229215 return obj + 10
230216 val = self .context .apply (bar , (name ,))
231217
232- self .assertEqual (val , [20 ] * self .num_targets )
218+ self .assertEqual (val , [20 ] * self .ntargets )
233219
234220 def test_apply_proxyize_sync (self ):
235221
0 commit comments