Skip to content

Commit 03a05bd

Browse files
authored
Enable parallel training UT in GitHub CI. (#1075)
* Remove duplicated setting of `allow_growth` in trainer.
* Make parallel training UT independent of its working folder.
* Skip parallel-training tests when there is only 1 GPU card.
* Enable parallel training UT in GitHub CI.
1 parent c716b9f commit 03a05bd

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

.github/workflows/test_python.yml

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -71,4 +71,9 @@ jobs:
7171
CXX: g++-${{ matrix.gcc }}
7272
TENSORFLOW_VERSION: ${{ matrix.tf }}
7373
- run: dp --version
74+
- name: Prepare parallel runtime
75+
if: ${{ matrix.tf == '' }}
76+
run: |
77+
sudo apt install libopenmpi-dev openmpi-bin
78+
HOROVOD_WITHOUT_GLOO=1 HOROVOD_WITH_TENSORFLOW=1 pip install horovod mpi4py
7479
- run: pytest --cov=deepmd source/tests && codecov

deepmd/train/trainer.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -397,7 +397,6 @@ def _init_session(self):
397397
config = get_tf_session_config()
398398
device, idx = self.run_opt.my_device.split(":", 1)
399399
if device == "gpu":
400-
config.gpu_options.allow_growth = True
401400
config.gpu_options.visible_device_list = idx
402401
self.sess = tf.Session(config=config)
403402

source/tests/test_parallel_training.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,9 +18,12 @@ def setUp(self):
1818
def test_two_workers(self):
1919
command = 'horovodrun -np 2 dp train -m workers ' + self.input_file
2020
penv = os.environ.copy()
21-
if len(get_gpus() or []) > 1:
21+
num_gpus = len(get_gpus() or [])
22+
if num_gpus > 1:
2223
penv['CUDA_VISIBLE_DEVICES'] = '0,1'
23-
popen = sp.Popen(command, shell=True, env=penv, stdout=sp.PIPE, stderr=sp.STDOUT)
24+
elif num_gpus == 1:
25+
raise unittest.SkipTest("At least 2 GPU cards are needed for parallel-training tests.")
26+
popen = sp.Popen(command, shell=True, cwd=str(tests_path), env=penv, stdout=sp.PIPE, stderr=sp.STDOUT)
2427
for line in iter(popen.stdout.readline, b''):
2528
if hasattr(line, 'decode'):
2629
line = line.decode('utf-8')

0 commit comments

Comments
 (0)