Skip to content

Commit f26ba5b

Browse files
author
chengduo
authored
Fuse AllReduce (#15921)
* fuse all_reduce test=develop * add fuse_parameter_groups_size test=develop * Polish code test=develop * Fix travis-ci test=develop * Add SetGroupAccordingToLayers and SetGroupAccordingToGroupSize test=develop * Add SetGroupAccordingToMemorySize test=develop * fix multi_devices_graph test=develop * reset params_grads test=develop * Polish code test=develop
1 parent d0ef682 commit f26ba5b

15 files changed

+1185
-49
lines changed

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place
99
cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
1010
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
1111
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
12+
cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
1213

1314
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
1415

@@ -22,6 +23,8 @@ endif()
2223
if(WITH_GPU)
2324
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
2425
dynload_cuda variable_visitor)
26+
nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
27+
dynload_cuda variable_visitor)
2528
if(WITH_DISTRIBUTE)
2629
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
2730
ddim dynload_cuda selected_rows_functor sendrecvop_rpc)
@@ -35,6 +38,8 @@ if(WITH_GPU)
3538
else()
3639
cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
3740
variable_visitor)
41+
cc_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
42+
variable_visitor)
3843
if(WITH_DISTRIBUTE)
3944
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
4045
ddim selected_rows_functor sendrecvop_rpc)
@@ -71,6 +76,8 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
7176
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
7277
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
7378

79+
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
80+
7481
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass)
7582
if (WITH_GPU)
7683
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
@@ -98,5 +105,5 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
98105
graph_viz_pass multi_devices_graph_pass
99106
multi_devices_graph_print_pass multi_devices_graph_check_pass
100107
fuse_elewise_add_act_pass multi_batch_merge_pass
101-
fuse_relu_depthwise_conv_pass
102-
memory_optimize_pass lock_free_optimize_pass)
108+
fuse_relu_depthwise_conv_pass
109+
memory_optimize_pass lock_free_optimize_pass alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass)

0 commit comments

Comments
 (0)