    reason="Backend specific test",
)
class JaxDistributionLibTest(testing.TestCase):
+    def _require_min_devices(self, min_devices):
+        """Skip test if fewer than min_devices are available."""
+        if len(jax.devices()) < min_devices:
+            pytest.skip(
+                f"Test requires at least {min_devices} devices, "
+                f"but only {len(jax.devices())} available"
+            )
+
    def _create_jax_layout(self, sharding):
        # Use jax_layout.Format or jax_layout.Layout if available.
        if hasattr(jax_layout, "Format"):
@@ -43,6 +51,7 @@ def _create_jax_layout(self, sharding):
        return sharding

    def test_list_devices(self):
+        self._require_min_devices(8)
        self.assertEqual(len(distribution_lib.list_devices()), 8)
        self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
        self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
@@ -77,6 +86,7 @@ def test_initialize_with_coordinator_address(self, mock_jax_initialize):
        )

    def test_distribute_tensor(self):
+        self._require_min_devices(8)
        jax_mesh = jax.sharding.Mesh(
            np.array(jax.devices()).reshape(2, 4), ("batch", "model")
        )
@@ -101,6 +111,7 @@ def test_function(inputs, target_layout):
        self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

    def test_distribute_variable(self):
+        self._require_min_devices(8)
        # This test only verifies the single worker/process behavior.
        jax_mesh = jax.sharding.Mesh(
            np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -118,6 +129,7 @@ def test_distribute_variable(self):
        self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

    def test_distribute_input_data(self):
+        self._require_min_devices(8)
        # This test only verifies the single worker/process behavior.
        # The multi-process test lives in g3.
        jax_mesh = jax.sharding.Mesh(
@@ -136,6 +148,7 @@ def test_distribute_input_data(self):
        self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

    def test_distribute_tensor_with_jax_layout(self):
+        self._require_min_devices(8)
        jax_mesh = jax.sharding.Mesh(
            np.array(jax.devices()).reshape(2, 4), ("batch", "model")
        )
@@ -166,6 +179,7 @@ def test_function(inputs, target_layout):
        )

    def test_distribute_variable_with_jax_layout(self):
+        self._require_min_devices(8)
        # This test only verifies the single worker/process behavior.
        jax_mesh = jax.sharding.Mesh(
            np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -187,6 +201,7 @@ def test_distribute_variable_with_jax_layout(self):
        )

    def test_distribute_input_data_with_jax_layout(self):
+        self._require_min_devices(8)
        # This test only verifies the single worker/process behavior.
        jax_mesh = jax.sharding.Mesh(
            np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -212,6 +227,7 @@ def test_processes(self):
        self.assertEqual(backend_dlib.num_processes(), 1)

    def test_to_backend_mesh(self):
+        self._require_min_devices(8)
        devices = [f"cpu:{i}" for i in range(8)]
        shape = (4, 2)
        axis_names = ["batch", "model"]
@@ -224,6 +240,7 @@ def test_to_backend_mesh(self):
        self.assertEqual(jax_mesh.axis_names, ("batch", "model"))

    def test_to_backend_layout(self):
+        self._require_min_devices(8)
        axes = ["data", None]
        mesh = distribution_lib.DeviceMesh(
            (4, 2), ["data", "model"], [f"cpu:{i}" for i in range(8)]
@@ -248,6 +265,7 @@ def test_validation_for_device_mesh(self):
            backend_dlib._to_backend_layout(layout)

    def test_variable_assignment_reuse_layout(self):
+        self._require_min_devices(8)
        shape = (4, 2)
        axis_names = ["batch", "model"]
        device_mesh = distribution_lib.DeviceMesh(
@@ -310,6 +328,7 @@ def test_e2e_data_parallel_model(self):
            model.fit(inputs, labels)

    def test_e2e_model_parallel_model(self):
+        self._require_min_devices(8)
        shape = (4, 2)
        axis_names = ["batch", "model"]
        device_mesh = distribution_lib.DeviceMesh(
@@ -349,6 +368,7 @@ def test_e2e_model_parallel_model(self):
            model.fit(inputs, labels)

    def test_e2e_model_parallel_with_output_sharding(self):
+        self._require_min_devices(8)
        shape = (4, 2)
        axis_names = ["batch", "model"]
        device_mesh = distribution_lib.DeviceMesh(
@@ -405,6 +425,7 @@ def test_e2e_model_parallel_with_output_sharding(self):
        )

    def test_distribute_data_input(self):
+        self._require_min_devices(4)
        per_process_batch = jax.numpy.arange(24).reshape(
            6, 4
        )  # Example input array
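Note (not part of the patch): the new `_require_min_devices` helper skips a test at run time when the host exposes fewer JAX devices than the test assumes (8 for the 2x4 mesh tests, 4 for `test_distribute_data_input`). A minimal sketch of how to make these tests run rather than skip on a CPU-only machine, assuming the standard XLA host-device-count flag is honored and is set before JAX initializes its backend:

# Sketch only: expose 8 virtual CPU devices so the 2x4 mesh tests can run.
# XLA_FLAGS must be set before the first call into jax (e.g. jax.devices()).
import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_force_host_platform_device_count=8"
)

import jax  # imported after setting XLA_FLAGS on purpose

assert len(jax.devices("cpu")) >= 8

One practical effect of calling the helper inside each test body, rather than using a skipif condition evaluated at module import, is that the device count is checked only when the test actually runs, after any such flags or conftest setup have taken effect.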