
Commit 0bf799a

wip testing

1 parent b9c28df · commit 0bf799a

7 files changed: 16 additions (+), 17 deletions (−)

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ else()
   set(multi_devices_graph_builder_deps)
 endif()
 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
-        scale_loss_grad_op_handle ${multi_devices_graph_builder_deps})
+        scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps})
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
         simple_threadpool device_context)

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 4 additions & 6 deletions

@@ -35,22 +35,20 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
     const std::vector<Scope *> &local_scopes,
-    platform::NCCLContextMap *nccl_ctxs, bool distributed)
+    platform::NCCLContextMap *nccl_ctxs)
     : loss_var_name_(loss_var_name),
       places_(places),
       local_scopes_(local_scopes),
-      distributed_(distributed),
       nccl_ctxs_(nccl_ctxs) {
 #else
 MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::vector<platform::Place> &places,
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
-    const std::vector<Scope *> &local_scopes, bool distributed)
+    const std::vector<Scope *> &local_scopes)
     : loss_var_name_(loss_var_name),
       places_(places),
-      local_scopes_(local_scopes),
-      distributed_(distributed) {
+      local_scopes_(local_scopes) {
 #endif
   for (auto &p : params) {
     grad_names_.insert(GradVarName(p));

@@ -99,7 +97,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
 
       // append send op if program is distributed trainer main program.
       // always use the first device
-      if (is_forwarding && distributed_ && op->Type() == "send") {
+      if (!is_forwarding && op->Type() == "send") {
         auto &p = places_[0];
         auto *s = local_scopes_[0];
         size_t i = 0;
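
The builder change above drops the `distributed_` constructor flag: whether the program is a distributed trainer program is now inferred from the ops themselves, i.e. from a `send` op appearing in the non-forward part of the program. A rough Python-side illustration of that detection rule, as a minimal sketch; the helper name is mine and not part of this commit:

```python
import paddle.fluid as fluid


def looks_like_distributed_trainer_program(program):
    # A trainer program prepared for distributed training contains a `send`
    # op; its presence (rather than a constructor flag) is what
    # MultiDevSSAGraphBuilder::Build now keys on when appending the send
    # op handle on the first device.
    return any(op.type == "send" for op in program.global_block().ops)


# A plain, non-distributed program has no send op.
print(looks_like_distributed_trainer_program(fluid.default_main_program()))  # False
```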

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 2 additions & 5 deletions

@@ -34,14 +34,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
                           const std::vector<Scope *> &local_scopes,
-                          platform::NCCLContextMap *nccl_ctxs,
-                          bool distributed = false);
+                          platform::NCCLContextMap *nccl_ctxs);
 #else
   MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
-                          const std::vector<Scope *> &local_scopes,
-                          bool distributed = false);
+                          const std::vector<Scope *> &local_scopes);
 #endif
 
   std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;

@@ -55,7 +53,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   const std::vector<platform::Place> &places_;
   const std::vector<Scope *> &local_scopes_;
   std::unordered_set<std::string> grad_names_;
-  bool distributed_;
 
 #ifdef PADDLE_WITH_CUDA
   platform::NCCLContextMap *nccl_ctxs_;

paddle/fluid/framework/parallel_executor.h

Lines changed: 2 additions & 2 deletions

@@ -48,13 +48,13 @@ class ParallelExecutor {
            const std::string& fetched_var_name,
            const std::unordered_map<std::string, LoDTensor>& feed_tensors);
 
+  void BCastParamsToGPUs(const std::unordered_set<std::string>& vars) const;
+
  private:
   void SplitTensorToPlaces(
       const std::unordered_map<std::string, LoDTensor>& feed_tensors);
 
   ParallelExecutorPrivate* member_;
-
-  void BCastParamsToGPUs(const std::unordered_set<std::string>& vars) const;
 };
 
 }  // namespace framework

paddle/fluid/operators/detail/serde_test.cc

Lines changed: 1 addition & 1 deletion

@@ -107,7 +107,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
   for (int i = 0; i < tensor_numel; ++i) {
     EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
   }
-  for (int64_t i = 0; i < rows2->size(); ++i) {
+  for (size_t i = 0; i < rows2->size(); ++i) {
     EXPECT_EQ(rows_data2[i], i);
   }
   EXPECT_EQ(slr2->height(), 1000);

paddle/fluid/pybind/pybind.cc

Lines changed: 1 addition & 0 deletions

@@ -554,6 +554,7 @@ All parameter, weight, gradient are variables in Paddle.
              bcast_vars, main_program, loss_var_name,
              scope, local_scopes, allow_op_delay);
       })
+      .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
       .def("local_scopes",
            [](ParallelExecutor &self) -> std::vector<Scope *> * {
              return &self.GetLocalScopes();

python/paddle/fluid/parallel_executor.py

Lines changed: 5 additions & 2 deletions

@@ -99,7 +99,7 @@ def __init__(self,
         local_scopes = share_vars_from.executor.local_scopes(
         ) if share_vars_from else []
 
-        persistable_vars = [
+        self.persistable_vars = [
             v.name
             for v in filter(lambda var: var.persistable, main.list_vars())
         ]

@@ -112,7 +112,7 @@ def __init__(self,
                 p.name for p in main.global_block().iter_parameters()
                 if not p.stop_gradient
             ]),
-            set(persistable_vars),
+            set(self.persistable_vars),
             main.desc,
             loss_name if loss_name else '',
             scope,

@@ -142,3 +142,6 @@ def run(self, fetch_list, feed_dict={}):
         self.executor.run(fetch_list, fetch_var_name, feed_tensor_dict)
         arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
         return [arr[i] for i in range(len(arr))]
+
+    def bcast_params(self):
+        self.executor.bcast_params(set(self.persistable_vars))
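
The Python wrapper now records the persistable variable names at construction time and exposes bcast_params(), which forwards to the new BCastParamsToGPUs binding. A minimal usage sketch follows; the toy network, the ParallelExecutor constructor arguments, and the feed shapes are assumptions for illustration, not part of this commit:

```python
import numpy
import paddle.fluid as fluid

# Assumed toy network: one fc layer with a squared-error loss.
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

# Initialize parameters once, then build the parallel executor.
fluid.Executor(fluid.CUDAPlace(0)).run(fluid.default_startup_program())
pe = fluid.ParallelExecutor(use_cuda=True, loss_name=loss.name)

feed = {
    'x': numpy.random.rand(32, 13).astype('float32'),
    'y': numpy.random.rand(32, 1).astype('float32'),
}
pe.run(fetch_list=[loss.name], feed_dict=feed)

# New in this commit: re-broadcast every persistable variable recorded in
# self.persistable_vars to all devices, e.g. after parameters were updated
# outside the executor (such as by a distributed trainer pulling new values).
pe.bcast_params()
```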
