Skip to content

Commit 8a463cd

Browse files
committed
fix gpubox trainning scripts and fix cuda error
1 parent 1d385ad commit 8a463cd

File tree

2 files changed

+46
-2
lines changed

2 files changed

+46
-2
lines changed

tools/run_gpubox.sh

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,44 @@
1-
#!/bin/bash
1+
# !/bin/bash
2+
3+
if [ ! -d "./log" ]; then
4+
mkdir ./log
5+
echo "Create log floder for store running log"
6+
fi
27

38
export FLAGS_LAUNCH_BARRIER=0
4-
fleetrun --worker_num=1 --server_num=1 tools/static_gpubox_trainer.py -m models/rank/dnn/config_gpubox.yaml
9+
export PADDLE_TRAINER_ID=0
10+
export PADDLE_PSERVER_NUMS=1
11+
export PADDLE_TRAINERS=1
12+
export PADDLE_TRAINERS_NUM=${PADDLE_TRAINERS}
13+
14+
15+
# set free port if 29011 is occupied
16+
export PADDLE_PSERVERS_IP_PORT_LIST="127.0.0.1:29011"
17+
export PADDLE_PSERVER_PORT_ARRAY=(29011)
18+
19+
# set gpu numbers according to your device
20+
export FLAGS_selected_gpus="0,1,2,3,4,5,6,7"
21+
22+
# set your model yaml
23+
SC="tools/static_gpubox_trainer.py -m models/rank/dnn/config_gpubox.yaml"
24+
25+
# run pserver
26+
export TRAINING_ROLE=PSERVER
27+
for((i=0;i<$PADDLE_PSERVER_NUMS;i++))
28+
do
29+
cur_port=${PADDLE_PSERVER_PORT_ARRAY[$i]}
30+
echo "PADDLE WILL START PSERVER "$cur_port
31+
export PADDLE_PORT=${cur_port}
32+
python -u $SC &> ./log/pserver.$i.log &
33+
done
34+
35+
# run trainer
36+
export TRAINING_ROLE=TRAINER
37+
for((i=0;i<$PADDLE_TRAINERS;i++))
38+
do
39+
echo "PADDLE WILL START Trainer "$i
40+
PADDLE_TRAINER_ID=$i
41+
python -u $SC &> ./log/worker.$i.log
42+
done
43+
44+
echo "Training log stored in ./log/"

tools/static_gpubox_trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def run(self):
7474
fleet.stop_worker()
7575
self.record_result()
7676
logger.info("Run Success, Exit.")
77+
logger.info("-" * 100)
7778

7879
def network(self):
7980
self.model = get_model(self.config)
@@ -150,6 +151,9 @@ def run_worker(self):
150151
self.exe, model_dir,
151152
[feed.name for feed in self.input_data],
152153
self.inference_target_var)
154+
self.reader.release_memory()
155+
self.PSGPU.end_pass()
156+
logger.info("finish {} epoch training....".format(epoch))
153157

154158
def init_reader(self):
155159
if fleet.is_server():

0 commit comments

Comments
 (0)