Skip to content

Commit 961fbce

Browse files
committed
follow comments
1 parent 7b72383 commit 961fbce

File tree

3 files changed

+6
-8
lines changed

3 files changed

+6
-8
lines changed

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,14 @@ cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
1212
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
1313

1414
if(WITH_GPU)
15-
nv_library(nccl_all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
15+
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
1616
dynload_cuda variable_visitor)
17-
set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
1817
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
1918
nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
2019

2120
else()
2221
cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
2322
variable_visitor)
24-
set(multi_devices_graph_builder_deps all_reduce_op_handle)
2523
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
2624
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
2725
endif()
@@ -30,7 +28,7 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
3028
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
3129

3230
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
33-
scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
31+
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle)
3432

3533

3634
cc_library(graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer)

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
240240
CreateReduceOp(&result, g_name, 0);
241241
CreateBroadcastOp(&result, g_name, 0);
242242
} else {
243-
InsertNCCLAllReduceOp(&result, g_name);
243+
InsertAllReduceOp(&result, g_name);
244244
}
245245
break;
246246
}
@@ -327,8 +327,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
327327
CreateOpHandleIOs(result, op, dev_id);
328328
}
329329

330-
void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
331-
SSAGraph *result, const std::string &og) const {
330+
void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
331+
const std::string &og) const {
332332
#ifdef PADDLE_WITH_CUDA
333333
result->ops_.emplace_back(
334334
new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
100100
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
101101
const OpDesc &op) const;
102102

103-
void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;
103+
void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
104104

105105
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
106106
size_t src_dev_id) const;

0 commit comments

Comments
 (0)