@@ -32,6 +32,7 @@


 class TestCollectiveAPIRunnerBase(object):
+
     def get_model(self, train_prog, startup_prog, rank, indata=None):
         raise NotImplementedError(
             "get model should be implemented by child class.")
@@ -91,6 +92,7 @@ def runtime_main(test_class, col_type):


 class TestDistBase(unittest.TestCase):
+
     def setUp(self):
         self._port_set = set()
         self._trainers = 2
@@ -104,6 +106,7 @@ def tearDown(self):
         self.temp_dir.cleanup()

     def _find_free_port(self):
+
         def __free_port():
             with closing(socket.socket(socket.AF_INET,
                                        socket.SOCK_STREAM)) as s:
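The `_find_free_port` hunk is truncated here. For reference, the standard idiom this kind of helper is built on is to bind an AF_INET socket to port 0 so the OS assigns an unused port, then read the port back. A minimal sketch under that assumption (the helper's exact body is not shown in this hunk):

    from contextlib import closing
    import socket

    def find_free_port():
        # Binding to port 0 asks the OS for any currently unused port.
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(('', 0))
            return s.getsockname()[1]  # the port the OS picked

The test keeps a `_port_set` so repeated calls can reject ports it has already handed out.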
@@ -168,17 +171,15 @@ def _run_cluster(self, model_file, envs):
         tr0_pipe = open(path0, "w")
         tr1_pipe = open(path1, "w")
         #print(tr0_cmd)
-        tr0_proc = subprocess.Popen(
-            tr0_cmd.strip().split(),
-            stdout=subprocess.PIPE,
-            stderr=tr0_pipe,
-            env=env0)
-
-        tr1_proc = subprocess.Popen(
-            tr0_cmd.strip().split(),
-            stdout=subprocess.PIPE,
-            stderr=tr1_pipe,
-            env=env1)
+        tr0_proc = subprocess.Popen(tr0_cmd.strip().split(),
+                                    stdout=subprocess.PIPE,
+                                    stderr=tr0_pipe,
+                                    env=env0)
+
+        tr1_proc = subprocess.Popen(tr1_cmd.strip().split(),
+                                    stdout=subprocess.PIPE,
+                                    stderr=tr1_pipe,
+                                    env=env1)

         tr0_out, tr0_err = tr0_proc.communicate()
         tr1_out, tr1_err = tr1_proc.communicate()
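Each rank must be launched with its own command string, environment, and stderr log; stdout is captured via a pipe and collected with `communicate()`, which blocks until the trainer exits. A hedged sketch of that pattern (the helper name `launch_trainer` is invented for illustration; `tr*_cmd`, `env*`, and `path*` follow the surrounding test):

    import subprocess

    def launch_trainer(cmd, env, log_path):
        # Split the command string into an argv list; stderr goes to a
        # per-rank log file, stdout is collected through a pipe.
        log_file = open(log_path, "w")
        return subprocess.Popen(cmd.strip().split(),
                                stdout=subprocess.PIPE,
                                stderr=log_file,
                                env=env)

    # tr0_proc = launch_trainer(tr0_cmd, env0, path0)
    # tr1_proc = launch_trainer(tr1_cmd, env1, path1)
    # tr0_out, _ = tr0_proc.communicate()  # blocks until rank 0 exits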
@@ -220,8 +221,14 @@ def check_with_place(self,
         required_envs["GLOG_v"] = "3"
         required_envs["GLOG_logtostderr"] = "1"
         required_envs["GLOO_LOG_LEVEL"] = "TRACE"
-        tr0_out, tr1_out, pid0, pid1 = self._run_cluster(model_file,
-                                                         required_envs)
+
+        if os.getenv('NVIDIA_TF32_OVERRIDE') is not None:
+            required_envs['NVIDIA_TF32_OVERRIDE'] = os.getenv(
+                'NVIDIA_TF32_OVERRIDE')
+
+        tr0_out, tr1_out, pid0, pid1 = self._run_cluster(
+            model_file, required_envs)
+
         np.random.seed(pid0)
         input1 = np.random.random((10, 1000))
         np.random.seed(pid1)
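The new block forwards `NVIDIA_TF32_OVERRIDE` from the parent process into the trainer environments so TF32 behavior stays consistent between the driver and the subprocesses. The guard relies on `os.getenv` returning `None` for an unset variable, so only a value the parent actually set gets propagated. The same guard as a small sketch, with a hypothetical helper name:

    import os

    def forward_env_var(envs, name):
        value = os.getenv(name)   # None when the variable is unset
        if value is not None:
            envs[name] = value    # propagate only values the parent set
        return envs

    # required_envs = forward_env_var(required_envs, 'NVIDIA_TF32_OVERRIDE')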
@@ -248,36 +255,33 @@ def check_with_place(self,
         elif col_type == "allreduce":
             need_result = input1 + input2
             self.assertTrue(
-                np.allclose(
-                    tr0_out, need_result, rtol=1e-05, atol=1e-05))
+                np.allclose(tr0_out, need_result, rtol=1e-05, atol=1e-05))
             self.assertTrue(
-                np.allclose(
-                    tr1_out, need_result, rtol=1e-05, atol=1e-05))
+                np.allclose(tr1_out, need_result, rtol=1e-05, atol=1e-05))
         elif col_type == "parallel_embedding":
             result_data = tr0_out[0]
             np.random.seed(2020)
             need_result = np.random.rand(12, 8)
             for i in range(result_data.shape[0]):
                 for j in range(result_data.shape[1]):
                     data = result_data[i][j]
-                    assert np.allclose(
-                        tr0_out[1][i][j], need_result[data], atol=1e-08)
+                    assert np.allclose(tr0_out[1][i][j],
+                                       need_result[data],
+                                       atol=1e-08)
         elif col_type == "row_parallel_linear":
             result_data = tr0_out[0]
             np.random.seed(2020)
             weight = np.random.rand(1000, 16)
             need_result = np.matmul(input1, weight)
             self.assertTrue(
-                np.allclose(
-                    result_data, need_result, rtol=1e-05, atol=1e-05))
+                np.allclose(result_data, need_result, rtol=1e-05, atol=1e-05))
         elif col_type == "column_parallel_linear":
             result_data = tr0_out[0]
             np.random.seed(2020)
             weight = np.random.rand(1000, 16)
             need_result = np.matmul(input1, weight)
             self.assertTrue(
-                np.allclose(
-                    result_data, need_result, rtol=1e-05, atol=1e-05))
+                np.allclose(result_data, need_result, rtol=1e-05, atol=1e-05))
         elif col_type == "alltoall":
             need_result1 = np.vstack((input1[0:input1.shape[0] // 2, :],
                                       input2[0:input2.shape[0] // 2, :]))
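All of these checks lean on the same `np.allclose` contract: it passes when `|actual - expected| <= atol + rtol * |expected|` holds elementwise, so `rtol=1e-05, atol=1e-05` tolerates about 1e-05 of absolute error plus five decimal digits of relative error. A tiny self-contained example:

    import numpy as np

    expected = np.array([1.0, 100.0])
    actual = expected + np.array([5e-6, 5e-4])  # within atol + rtol * |expected|
    assert np.allclose(actual, expected, rtol=1e-05, atol=1e-05)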
@@ -286,16 +290,13 @@ def check_with_place(self,
             tr0_out = np.vstack(tr0_out)
             tr1_out = np.vstack(tr1_out)
             self.assertTrue(
-                np.allclose(
-                    tr0_out, need_result1, rtol=1e-05, atol=1e-05))
+                np.allclose(tr0_out, need_result1, rtol=1e-05, atol=1e-05))
             self.assertTrue(
-                np.allclose(
-                    tr1_out, need_result2, rtol=1e-05, atol=1e-05))
+                np.allclose(tr1_out, need_result2, rtol=1e-05, atol=1e-05))
         elif col_type == "sendrecv":
             result_data = tr1_out[0]
             self.assertTrue(
-                np.allclose(
-                    input1, result_data, rtol=1e-05, atol=1e-05))
+                np.allclose(input1, result_data, rtol=1e-05, atol=1e-05))
         elif col_type == "global_gather":
             in_feat = 2
             n_expert = 2
@@ -372,15 +373,13 @@ def check_with_place(self,
             if result1 == []:
                 output1 = np.array([])
             else:
-                output1 = np.concatenate(
-                    result1, axis=0).reshape(
-                        sum(local_expert_count1), in_feat)
+                output1 = np.concatenate(result1, axis=0).reshape(
+                    sum(local_expert_count1), in_feat)
             if result2 == []:
                 output2 = np.array([])
             else:
-                output2 = np.concatenate(
-                    result2, axis=0).reshape(
-                        sum(local_expert_count2), in_feat)
+                output2 = np.concatenate(result2, axis=0).reshape(
+                    sum(local_expert_count2), in_feat)

             if tr0_out[0] is None or tr0_out[0].shape[0] == 0:
                 tr0_out[0] = np.array([])
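The concatenate-and-reshape above stitches the per-expert result chunks back into one `(total_rows, in_feat)` matrix; since the concatenated array already has `sum(local_expert_count)` rows, the reshape acts mostly as a shape assertion. A toy illustration with made-up chunk sizes:

    import numpy as np

    in_feat = 2
    result1 = [np.zeros((2, in_feat)), np.ones((3, in_feat))]  # per-expert chunks
    local_expert_count1 = [2, 3]
    output1 = np.concatenate(result1, axis=0).reshape(
        sum(local_expert_count1), in_feat)
    assert output1.shape == (5, in_feat)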
@@ -389,24 +388,20 @@ def check_with_place(self,
                 tr1_out[0] = np.array([])

             self.assertTrue(
-                np.allclose(
-                    tr0_out[0], output1, rtol=1e-05, atol=1e-05))
+                np.allclose(tr0_out[0], output1, rtol=1e-05, atol=1e-05))
             self.assertTrue(
-                np.allclose(
-                    tr1_out[0], output2, rtol=1e-05, atol=1e-05))
+                np.allclose(tr1_out[0], output2, rtol=1e-05, atol=1e-05))
             if static_mode == 0:
                 self.assertTrue(
-                    np.allclose(
-                        tr0_out[1],
-                        2 * local_input_buf1,
-                        rtol=1e-05,
-                        atol=1e-05))
+                    np.allclose(tr0_out[1],
+                                2 * local_input_buf1,
+                                rtol=1e-05,
+                                atol=1e-05))
                 self.assertTrue(
-                    np.allclose(
-                        tr1_out[1],
-                        2 * local_input_buf2,
-                        rtol=1e-05,
-                        atol=1e-05))
+                    np.allclose(tr1_out[1],
+                                2 * local_input_buf2,
+                                rtol=1e-05,
+                                atol=1e-05))

         elif col_type == "global_scatter":
             np.random.seed(pid0)
@@ -460,23 +455,19 @@ def check_with_place(self,
                 tr1_out[0] = np.array([])

             self.assertTrue(
-                np.allclose(
-                    tr0_out[0], output1, rtol=1e-05, atol=1e-05))
+                np.allclose(tr0_out[0], output1, rtol=1e-05, atol=1e-05))
             self.assertTrue(
-                np.allclose(
-                    tr1_out[0], output2, rtol=1e-05, atol=1e-05))
+                np.allclose(tr1_out[0], output2, rtol=1e-05, atol=1e-05))
             if static_mode == 0:
                 self.assertTrue(
-                    np.allclose(
-                        tr0_out[1],
-                        2 * local_input_buf1,
-                        rtol=1e-05,
-                        atol=1e-05))
+                    np.allclose(tr0_out[1],
+                                2 * local_input_buf1,
+                                rtol=1e-05,
+                                atol=1e-05))
                 self.assertTrue(
-                    np.allclose(
-                        tr1_out[1],
-                        2 * local_input_buf2,
-                        rtol=1e-05,
-                        atol=1e-05))
+                    np.allclose(tr1_out[1],
+                                2 * local_input_buf2,
+                                rtol=1e-05,
+                                atol=1e-05))
         else:
             pass