Commit 515b206

wangxicodinggongweibao authored and committed

[Cherry-pick 1.6] fix batch_norm_grad shape=0 & allreduce shape enforce & sync_batch_norm hang in fleet (#22157)

1 parent: b9a1d95

File tree: 7 files changed, +87 −12 lines

paddle/fluid/framework/details/all_reduce_op_handle.cc

Lines changed: 5 additions & 0 deletions

@@ -84,6 +84,11 @@ void AllReduceOpHandle::AllReduceImpl(
 
     if (i == 0) {
       numel = static_cast<int64_t>(lod_tensor.numel());
+      // Only enforce on place 0; other places' numel is checked against place 0's numel.
+      PADDLE_ENFORCE_GT(
+          numel, 0, platform::errors::InvalidArgument(
+                        "The numel of tensor=[%s] must be > 0. But now numel=[%d]",
+                        in_var_handles[i]->name(), numel));
       dtype = lod_tensor.type();
       is_gpu_place = platform::is_gpu_place(lod_tensor.place());
     }
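
For illustration only, a minimal sketch (hypothetical shapes, not taken from the commit) of the condition this check guards against: a gradient whose inferred shape contains a zero-sized dimension has numel == 0, and the PADDLE_ENFORCE_GT above now reports it as InvalidArgument on place 0 instead of handing a zero-sized buffer to the collective.

# Hypothetical example: a zero-element gradient is what the new check rejects.
import numpy as np

grad = np.empty([0, 128], dtype='float32')  # shape inferred with a 0 dimension
print(grad.size)  # 0 -> such a tensor would now fail the numel > 0 enforce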

paddle/fluid/operators/batch_norm_op.cc

Lines changed: 13 additions & 6 deletions

@@ -425,11 +425,17 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
 
   // check output
   PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), "");
-  if (ctx->HasOutput(framework::GradVarName("Scale"))) {
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
-                   "Output(Scale@GRAD) and Output(Bias@GRAD) should not be "
-                   "null at same time");
-  }
+
+  const bool has_scale_grad = ctx->HasOutput(framework::GradVarName("Scale"));
+  const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("Bias"));
+
+  PADDLE_ENFORCE_EQ((has_scale_grad == has_bias_grad), true,
+                    platform::errors::InvalidArgument(
+                        "Output(Scale@GRAD) and Output(Bias@GRAD) must both be "
+                        "null or both be non-null at the same time. But now, "
+                        "has Scale@GRAD=[%d], has Bias@GRAD=[%d]",
+                        has_scale_grad, has_bias_grad));
+
   const bool use_global_stats = ctx->Attrs().Get<bool>("use_global_stats");
   if (use_global_stats) {
     PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_mkldnn"),

@@ -444,7 +450,8 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
            : x_dims[x_dims.size() - 1]);
 
   ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-  if (ctx->HasOutput(framework::GradVarName("Scale"))) {
+  // has_scale_grad == has_bias_grad, so checking has_scale_grad is enough.
+  if (has_scale_grad) {
     ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
     ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
   }
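
To show when both gradient outputs are legitimately absent, here is a minimal sketch (assuming the fluid 1.6 static-graph API, and assuming that marking a parameter non-trainable drops its @GRAD output): with both scale and bias frozen, neither Scale@GRAD nor Bias@GRAD is requested, which is the "both null" case the relaxed check accepts; freezing only one of them would be the mismatched case that now raises InvalidArgument.

# Hypothetical sketch: both batch_norm parameters are frozen, so the grad op
# carries neither Scale@GRAD nor Bias@GRAD.
import paddle.fluid as fluid

data = fluid.layers.data(name='img', shape=[3, 32, 32], dtype='float32')
conv = fluid.layers.conv2d(input=data, num_filters=8, filter_size=3)
bn = fluid.layers.batch_norm(
    input=conv,
    param_attr=fluid.ParamAttr(trainable=False),  # no Scale@GRAD requested
    bias_attr=fluid.ParamAttr(trainable=False))   # no Bias@GRAD requested
loss = fluid.layers.mean(bn)
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)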

python/paddle/fluid/incubate/fleet/collective/__init__.py

Lines changed: 11 additions & 0 deletions

@@ -297,6 +297,17 @@ def _try_to_compile(self, startup_program, main_program):
                 "with multi nccl comm, please export FLAGS_sync_nccl_allreduce = 0"
             )
 
+        # NOTE: enabling sync_batch_norm will hang when multiple num_threads are used.
+        sync_batch_norm = self._strategy.sync_batch_norm
+        if sync_batch_norm is not None and sync_batch_norm is True:
+            self._strategy.nccl_comm_num = 1
+            self._strategy.use_hierarchical_allreduce = False
+            exec_strategy.num_threads = 1
+            logging.warn(
+                "sync_batch_norm will hang when num_threads > 1, so "
+                "num_threads=1, nccl_comm_num=1 and use_hierarchical_allreduce=False are set."
+            )
+
         if self.print_config:
             print("node_num:", node_num, "num_threads:",
                   exec_strategy.num_threads, "use_hierarchical_allreduce:",
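
As a usage sketch (condensed from the FleetCollectiveTest added below; it assumes a CUDA build and a single local trainer): once sync_batch_norm is turned on in DistributedStrategy, the compile step above logs a warning and falls back to one execution thread, a single NCCL communicator, and no hierarchical allreduce.

# Usage sketch, mirroring the new unit test below; assumes a CUDA build.
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy

role = role_maker.UserDefinedCollectiveRoleMaker(0, ['127.0.0.1:6170'])
fleet.init(role)

data = fluid.layers.data(name='X', shape=[1], dtype='float32')
loss = fluid.layers.mean(fluid.layers.fc(input=data, size=10))

dist_strategy = DistributedStrategy()
dist_strategy.sync_batch_norm = True

fleet.distributed_optimizer(
    fluid.optimizer.AdamOptimizer(), strategy=dist_strategy).minimize(loss)

# After minimize(), the strategy reflects the sync_batch_norm constraints:
assert dist_strategy.exec_strategy.num_threads == 1
assert dist_strategy.nccl_comm_num == 1
assert dist_strategy.use_hierarchical_allreduce is False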

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 14 additions & 5 deletions

@@ -126,11 +126,20 @@ function(bash_test_modules TARGET_NAME)
     set(timeout ${bash_test_modules_TIMEOUT})
   endif()
 
-  add_test(NAME ${TARGET_NAME}
-           COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
-           TEST_TARGET_NAME=${TARGET_NAME} TEST_TIMEOUT=${timeout} ${bash_test_modules_ENVS}
-           bash ${CMAKE_CURRENT_BINARY_DIR}/${bash_test_modules_MODULES}
-           WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
+  if(WITH_COVERAGE)
+    add_test(NAME ${TARGET_NAME}
+             COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
+             TEST_TARGET_NAME=${TARGET_NAME} TEST_TIMEOUT=${timeout} ${bash_test_modules_ENVS}
+             WITH_COVERAGE=ON COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data
+             bash ${CMAKE_CURRENT_BINARY_DIR}/${bash_test_modules_MODULES}
+             WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
+  else()
+    add_test(NAME ${TARGET_NAME}
+             COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
+             TEST_TARGET_NAME=${TARGET_NAME} TEST_TIMEOUT=${timeout} ${bash_test_modules_ENVS}
+             bash ${CMAKE_CURRENT_BINARY_DIR}/${bash_test_modules_MODULES}
+             WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
+  endif()
 
   if (bash_test_modules_SERIAL)
     set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)

python/paddle/fluid/tests/unittests/dist_test.sh

Lines changed: 9 additions & 1 deletion

@@ -20,7 +20,15 @@ rm -f ${name}*.log
 # start the unit test
 run_time=$(( $TEST_TIMEOUT - 10 ))
 echo "run_time: ${run_time}"
-timeout -s SIGKILL ${run_time} python -u ${name}.py > ${name}_run.log 2>&1
+
+if [[ ${WITH_COVERAGE} == "ON" ]]; then
+    PYTHON_EXEC="python -u -m coverage run --branch -p "
+else
+    PYTHON_EXEC="python -u "
+fi
+
+timeout -s SIGKILL ${run_time} ${PYTHON_EXEC} ${name}.py > ${name}_run.log 2>&1
+
 exit_code=$?
 if [[ $exit_code -eq 0 ]]; then
     exit 0

python/paddle/fluid/tests/unittests/test_dist_base.py

Lines changed: 5 additions & 0 deletions

@@ -136,6 +136,8 @@ def run_gpu_fleet_api_trainer(self, args):
             dist_strategy.use_local_sgd = True
         if args.ut4grad_allreduce:
             dist_strategy._ut4grad_allreduce = True
+        if args.sync_batch_norm:
+            dist_strategy.sync_batch_norm = True
 
         role = role_maker.PaddleCloudRoleMaker(is_collective=True)
         fleet.init(role)

@@ -445,6 +447,7 @@ def runtime_main(test_class):
        required=False,
        type=bool,
        default=False)
+    parser.add_argument('--sync_batch_norm', action='store_true')
 
     args = parser.parse_args()
 

@@ -776,6 +779,8 @@ def _get_nccl2_trainer_cmd(self, model, ep, update_method, trainer_id,
                 tr_cmd += " --use_local_sgd"
             if self._ut4grad_allreduce:
                 tr_cmd += " --ut4grad_allreduce"
+            if hasattr(self, '_sync_batch_norm') and self._sync_batch_norm:
+                tr_cmd += " --sync_batch_norm"
 
             if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
                 env['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')

python/paddle/fluid/tests/unittests/test_dist_mnist_fleetapi.py

Lines changed: 30 additions & 0 deletions

@@ -24,12 +24,42 @@ def _setup_config(self):
         self._use_reader_alloc = False
         self._nccl2_mode = True
         self._gpu_fleet_api = True
+        self._sync_batch_norm = True
 
     def test_dist_train(self):
         import paddle.fluid as fluid
         if fluid.core.is_compiled_with_cuda():
             self.check_with_place("dist_mnist.py", delta=1e-5)
 
 
+class FleetCollectiveTest(unittest.TestCase):
+    def test_open_sync_batch_norm(self):
+        import paddle.fluid as fluid
+        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
+        from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
+
+        if not fluid.core.is_compiled_with_cuda():
+            # Operator "gen_nccl_id" has not been registered
+            return
+
+        data = fluid.layers.data(name='X', shape=[1], dtype='float32')
+        hidden = fluid.layers.fc(input=data, size=10)
+        loss = fluid.layers.mean(hidden)
+
+        optimizer = fluid.optimizer.AdamOptimizer()
+
+        role = role_maker.UserDefinedCollectiveRoleMaker(0, ['127.0.0.1:6170'])
+        fleet.init(role)
+
+        dist_strategy = DistributedStrategy()
+        dist_strategy.sync_batch_norm = True
+
+        dist_optimizer = fleet.distributed_optimizer(
+            optimizer, strategy=dist_strategy)
+        dist_optimizer.minimize(loss)
+
+        self.assertEqual(dist_strategy.exec_strategy.num_threads, 1)
+
+
 if __name__ == "__main__":
     unittest.main()
