Skip to content

Commit e5b3220

Browse files
committed
clean
1 parent ec6ee0a commit e5b3220

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ static const char kLocalScopes[] = "local_scopes";
127127
static const char kStrategy[] = "strategy";
128128

129129
void MultiDevSSAGraphBuilder::Init() const {
130+
all_vars_.clear();
131+
balance_vars_.clear();
132+
130133
loss_var_name_ = Get<const std::string>(kLossVarName);
131134
places_ = Get<const std::vector<platform::Place>>(kPlaces);
132135
local_scopes_ = Get<const std::vector<Scope *>>(kLocalScopes);

paddle/fluid/framework/details/multi_devices_graph_pass.h

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
4040
size_t device_id) const;
4141
void Init() const;
4242

43-
private:
44-
mutable std::string loss_var_name_;
45-
mutable std::vector<platform::Place> places_;
46-
mutable std::vector<Scope *> local_scopes_;
47-
mutable std::unordered_set<std::string> grad_names_;
48-
4943
#ifdef PADDLE_WITH_CUDA
5044
mutable platform::NCCLContextMap *nccl_ctxs_;
5145
#endif
@@ -95,13 +89,17 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
9589
size_t GetAppropriateDeviceID(
9690
const std::vector<std::string> &var_names) const;
9791

98-
private:
92+
void SetCommunicationContext(OpHandleBase *op_handle,
93+
const platform::Place &p) const;
94+
95+
mutable std::string loss_var_name_;
96+
mutable std::vector<platform::Place> places_;
97+
mutable std::vector<Scope *> local_scopes_;
98+
mutable std::unordered_set<std::string> grad_names_;
99+
99100
mutable BuildStrategy strategy_;
100101
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
101102
mutable std::vector<int64_t> balance_vars_;
102-
103-
void SetCommunicationContext(OpHandleBase *op_handle,
104-
const platform::Place &p) const;
105103
};
106104
} // namespace details
107105
} // namespace framework

0 commit comments

Comments
 (0)