Skip to content

Commit 98c12b1

Browse files
panyx0718Yang Yang(Tony)
authored andcommitted
Clean up C++ codes. (#10022)
* Privatize OpHandleBase * Clean up a few private members
1 parent 777cb55 commit 98c12b1

14 files changed

+85
-54
lines changed

paddle/fluid/framework/details/broadcast_op_handle.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ namespace framework {
2929
namespace details {
3030

3131
struct BroadcastOpHandle : public OpHandleBase {
32-
const std::vector<Scope *> &local_scopes_;
33-
const std::vector<platform::Place> &places_;
34-
32+
public:
3533
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
3634
const std::vector<platform::Place> &places);
3735

@@ -41,6 +39,10 @@ struct BroadcastOpHandle : public OpHandleBase {
4139

4240
protected:
4341
void RunImpl() override;
42+
43+
private:
44+
const std::vector<Scope *> &local_scopes_;
45+
const std::vector<platform::Place> &places_;
4446
};
4547

4648
} // namespace details

paddle/fluid/framework/details/broadcast_op_handle_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ struct TestBroadcastOpHandle {
9090
op_handle_->AddInput(dummy_var_handle);
9191

9292
for (size_t j = 0; j < gpu_list_.size(); ++j) {
93-
op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get();
93+
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
9494
VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
9595
vars_.emplace_back(out_var_handle);
9696
op_handle_->AddOutput(out_var_handle);

paddle/fluid/framework/details/computation_op_handle.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
2828
void ComputationOpHandle::RunImpl() {
2929
auto *cur_ctx = dev_ctxes_[place_];
3030
for (auto *in : inputs_) {
31-
bool need_wait =
32-
in->generated_op_ && in->generated_op_->dev_ctxes_[place_] != cur_ctx;
31+
bool need_wait = in->generated_op_ &&
32+
in->generated_op_->DeviceContext(place_) != cur_ctx;
3333
if (need_wait) {
3434
in->generated_op_->Wait(cur_ctx);
3535
}

paddle/fluid/framework/details/computation_op_handle.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
#pragma once
1616

17+
#include <string>
18+
#include <vector>
19+
1720
#include "paddle/fluid/framework/details/op_handle_base.h"
1821
#include "paddle/fluid/framework/op_registry.h"
1922
#include "paddle/fluid/framework/operator.h"
@@ -24,17 +27,19 @@ namespace paddle {
2427
namespace framework {
2528
namespace details {
2629
struct ComputationOpHandle : public OpHandleBase {
27-
std::unique_ptr<OperatorBase> op_;
28-
Scope *scope_;
29-
platform::Place place_;
30-
30+
public:
3131
ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
3232
platform::Place place);
3333

3434
std::string Name() const override;
3535

3636
protected:
3737
void RunImpl() override;
38+
39+
private:
40+
std::unique_ptr<OperatorBase> op_;
41+
Scope *scope_;
42+
platform::Place place_;
3843
};
3944
} // namespace details
4045
} // namespace framework

paddle/fluid/framework/details/fetch_op_handle.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
#pragma once
1616

17+
#include <string>
18+
#include <vector>
19+
1720
#include "paddle/fluid/framework/details/op_handle_base.h"
1821
#include "paddle/fluid/framework/feed_fetch_type.h"
1922
#include "paddle/fluid/framework/scope.h"
@@ -24,11 +27,7 @@ namespace framework {
2427
namespace details {
2528

2629
struct FetchOpHandle : public OpHandleBase {
27-
FeedFetchList *data_;
28-
size_t offset_;
29-
std::vector<Scope *> *local_scopes_;
30-
std::vector<LoDTensor> tensors_;
31-
30+
public:
3231
FetchOpHandle(FeedFetchList *data, size_t offset,
3332
std::vector<Scope *> *local_scopes);
3433

@@ -42,6 +41,12 @@ struct FetchOpHandle : public OpHandleBase {
4241

4342
protected:
4443
void RunImpl() override;
44+
45+
private:
46+
FeedFetchList *data_;
47+
size_t offset_;
48+
std::vector<Scope *> *local_scopes_;
49+
std::vector<LoDTensor> tensors_;
4550
};
4651

4752
} // namespace details

paddle/fluid/framework/details/gather_op_handle.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ namespace framework {
2929
namespace details {
3030

3131
struct GatherOpHandle : public OpHandleBase {
32-
const std::vector<Scope *> &local_scopes_;
33-
const std::vector<platform::Place> &places_;
34-
32+
public:
3533
GatherOpHandle(const std::vector<Scope *> &local_scopes,
3634
const std::vector<platform::Place> &places);
3735

@@ -41,6 +39,10 @@ struct GatherOpHandle : public OpHandleBase {
4139

4240
protected:
4341
void RunImpl() override;
42+
43+
private:
44+
const std::vector<Scope *> &local_scopes_;
45+
const std::vector<platform::Place> &places_;
4446
};
4547

4648
} // namespace details

paddle/fluid/framework/details/gather_op_handle_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ struct TestGatherOpHandle {
7878
op_handle_.reset(new GatherOpHandle(local_scopes_, gpu_list_));
7979
// add input
8080
for (size_t j = 0; j < gpu_list_.size(); ++j) {
81-
op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get();
81+
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
8282
auto* in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]);
8383
vars_.emplace_back(in_var_handle);
8484
op_handle_->AddInput(in_var_handle);

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
6060
const platform::Place &p,
6161
const size_t &i) const {
6262
auto *op_handle = result->ops_.back().get();
63-
op_handle->dev_ctxes_[p] = platform::DeviceContextPool::Instance().Get(p);
63+
op_handle->SetDeviceContext(p,
64+
platform::DeviceContextPool::Instance().Get(p));
6465

6566
auto var_names = op.InputArgumentNames();
6667

paddle/fluid/framework/details/nccl_all_reduce_op_handle.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,6 @@ namespace framework {
2727
namespace details {
2828

2929
struct NCCLAllReduceOpHandle : public OpHandleBase {
30-
const std::vector<Scope *> &local_scopes_;
31-
const std::vector<platform::Place> &places_;
32-
const platform::NCCLContextMap &nccl_ctxs_;
33-
3430
NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes,
3531
const std::vector<platform::Place> &places,
3632
const platform::NCCLContextMap &ctxs);
@@ -43,6 +39,11 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
4339

4440
protected:
4541
void RunImpl() override;
42+
43+
private:
44+
const std::vector<Scope *> &local_scopes_;
45+
const std::vector<platform::Place> &places_;
46+
const platform::NCCLContextMap &nccl_ctxs_;
4647
};
4748

4849
} // namespace details

paddle/fluid/framework/details/op_handle_base.h

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,28 +27,15 @@ namespace details {
2727
constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
2828

2929
class OpHandleBase {
30-
private:
31-
DISABLE_COPY_AND_ASSIGN(OpHandleBase);
32-
3330
public:
34-
std::vector<VarHandleBase *> inputs_;
35-
std::vector<VarHandleBase *> outputs_;
36-
std::unordered_map<platform::Place, platform::DeviceContext *,
37-
platform::PlaceHash>
38-
dev_ctxes_;
39-
40-
#ifdef PADDLE_WITH_CUDA
41-
std::unordered_map<int, cudaEvent_t> events_;
42-
#endif
43-
4431
OpHandleBase() {}
4532

33+
virtual ~OpHandleBase();
34+
4635
std::string DebugString() const;
4736

4837
virtual std::string Name() const = 0;
4938

50-
virtual ~OpHandleBase();
51-
5239
void Run(bool use_event);
5340

5441
virtual void Wait(platform::DeviceContext *waited_dev);
@@ -61,13 +48,37 @@ class OpHandleBase {
6148
// will likely block other computations.
6249
virtual bool IsMultiDeviceTransfer() { return false; }
6350

51+
const platform::DeviceContext *DeviceContext(platform::Place place) {
52+
return dev_ctxes_[place];
53+
}
54+
55+
void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) {
56+
dev_ctxes_[place] = ctx_;
57+
}
58+
59+
const std::vector<VarHandleBase *> &Inputs() const { return inputs_; }
60+
61+
const std::vector<VarHandleBase *> &Outputs() const { return outputs_; }
62+
6463
protected:
6564
void RunAndRecordEvent(const std::function<void()> &callback);
6665

6766
void RunAndRecordEvent(platform::Place p,
6867
const std::function<void()> &callback);
6968

7069
virtual void RunImpl() = 0;
70+
71+
std::vector<VarHandleBase *> inputs_;
72+
std::vector<VarHandleBase *> outputs_;
73+
std::unordered_map<platform::Place, platform::DeviceContext *,
74+
platform::PlaceHash>
75+
dev_ctxes_;
76+
77+
#ifdef PADDLE_WITH_CUDA
78+
std::unordered_map<int, cudaEvent_t> events_;
79+
#endif
80+
81+
DISABLE_COPY_AND_ASSIGN(OpHandleBase);
7182
};
7283

7384
} // namespace details

0 commit comments

Comments
 (0)