
Commit 881e063

Commit message: follow comments
Parent commit: ff599b9

File tree

6 files changed (+46, -59 lines)

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 16 additions & 19 deletions
@@ -53,42 +53,39 @@ void BroadcastOpHandle::RunImpl() {
   Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);

-  // NOTE(zcd): the Place of input can get from in_tensor and in_var_handle ,
-  // maybe they are different, because the Place that getting from in_tensor is
-  // determined at runtime, the other is determined at building SSA graph stage.
-  // If they are different, DataTransform should be applied. Currently, it has
-  // not been done yet.
+  // NOTE: The tensors' Place of input and output must be all on GPU or all on
+  // CPU.
   for (auto *out_var_handle : out_var_handles) {
-    if (*out_var_handle == *in_var_handle) {
+    if (out_var_handle->IsTheSameVar(*in_var_handle)) {
       continue;
     }
-    auto &out_p = out_var_handle->place_;
+    auto t_out_p = out_var_handle->place_;
     auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
                         ->FindVar(out_var_handle->name_);
     PADDLE_ENFORCE_NOT_NULL(out_var);
-    PADDLE_ENFORCE_EQ(
-        out_p.which(), in_tensor.place().which(),
-        "Currently, Places of input and output must be all on CPU "
-        "or all on GPU.");
+    if (platform::is_gpu_place(in_tensor.place())) {
+      PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
+                     "Places of input and output must be all on GPU.");
+    } else {
+      t_out_p = platform::CPUPlace();
+    }
     VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
-    VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
+    VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
                                                             in_tensor.type());
   }

   if (platform::is_cpu_place(in_tensor.place())) {
     for (auto *out_var_handle : out_var_handles) {
-      if (*out_var_handle == *in_var_handle) {
+      if (out_var_handle->IsTheSameVar(*in_var_handle)) {
         continue;
       }
-
       auto &out_p = out_var_handle->place_;
-      auto dev_ctx = dev_ctxes_.at(out_p);
       auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
                           ->FindVar(out_var_handle->name_);

-      RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
+      RunAndRecordEvent(out_p, [in_tensor, out_var] {
         paddle::framework::TensorCopy(
-            in_tensor, out_p, *dev_ctx,
+            in_tensor, platform::CPUPlace(),
             &VariableVisitor::GetMutableTensor(out_var));
       });
     }

@@ -134,8 +131,8 @@ void BroadcastOpHandle::RunImpl() {
       call();
     }
   }
-  // TODO(zcd): Maybe the unequal operator is not appropriate here.
-  if (*out_handle != *in_var_handle) {
+
+  if (!out_handle->IsTheSameVar(*in_var_handle)) {
    auto out_var = var_scopes.at(in_var_handle->scope_idx_)
                       ->FindVar(out_var_handles[0]->name_);
    paddle::framework::TensorCopy(
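
The broadcast, gather, and reduce handles in this commit all replace the old variant-index comparison (the .which() check on platform::Place) with the same rule: a GPU input requires a GPU output place, while a CPU input forces the output onto platform::CPUPlace() regardless of what the output VarHandle recorded at graph-building time. The following standalone sketch only restates that rule; the Place enum and the NormalizeOutPlace helper are mock names invented for this illustration and do not exist in Paddle (platform::is_gpu_place is a real Paddle function, but here it is re-declared on the mock type).

#include <stdexcept>

// Mock stand-in for platform::Place (the real type is a variant, which is
// why the old check compared .which()).
enum class Place { kCPU, kGPU };

bool is_gpu_place(Place p) { return p == Place::kGPU; }

// Hypothetical helper: decide where the output tensor should actually be
// allocated, given where the input tensor lives and what the output
// VarHandle recorded when the SSA graph was built.
Place NormalizeOutPlace(Place in_tensor_place, Place out_handle_place) {
  if (is_gpu_place(in_tensor_place)) {
    // GPU input: the output place must also be a GPU place.
    if (!is_gpu_place(out_handle_place)) {
      throw std::runtime_error(
          "Places of input and output must be all on GPU.");
    }
    return out_handle_place;
  }
  // CPU input: always write the output on the CPU, whatever the handle says.
  return Place::kCPU;
}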

paddle/fluid/framework/details/gather_op_handle.cc

Lines changed: 14 additions & 13 deletions
@@ -75,14 +75,15 @@ void GatherOpHandle::RunImpl() {
     in_tensors.emplace_back(in_sr_value.value());
   }

-  // TODO(zcd): The Place of var_handle is determined at building SSA graph
-  // stage, while the Place of var is determined at runtime. If they are
-  // different, DataTransform should be applied. Currently, it has not been done
-  // yet.
-  auto &out_place = out_var_handle->place_;
-  PADDLE_ENFORCE_EQ(out_place.which(), pre_in_value.place().which(),
-                    "Currently, Places of input and output must be all on CPU "
-                    "or all on GPU.");
+  // NOTE: The Places of all input tensor must be all on CPU or all on GPU.
+  platform::Place t_out_p = out_var_handle->place_;
+  if (platform::is_gpu_place(pre_in_value.place())) {
+    PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
+                   "Places of input and output must be all on GPU.");
+  } else {
+    t_out_p = platform::CPUPlace();
+  }
+
   auto out_var =
       var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_);
   PADDLE_ENFORCE_NOT_NULL(out_var);

@@ -93,18 +94,18 @@
   DDim out_dim = pre_in_value.GetCompleteDims();
   out_dim[0] = static_cast<int64_t>(rows);
   out_value->mutable_value()->Resize(out_dim).mutable_data(
-      out_place, pre_in_value.value().type());
+      t_out_p, pre_in_value.value().type());
   Tensor *out_tensor = out_value->mutable_value();

   // copy
-  auto dev_ctx = dev_ctxes_[out_place];
-  RunAndRecordEvent(out_place, [in_tensors, out_tensor, &dev_ctx, out_place] {
+  auto dev_ctx = dev_ctxes_[out_var_handle->place_];
+  RunAndRecordEvent(out_var_handle->place_, [in_tensors, out_tensor, &dev_ctx,
+                                             t_out_p] {
     int s = 0, e = 0;
     for (size_t j = 0; j < in_tensors.size(); ++j) {
       e += in_tensors[j].dims()[0];
       auto sub_out = out_tensor->Slice(s, e);
-      paddle::framework::TensorCopy(in_tensors[j], out_place, *dev_ctx,
-                                    &sub_out);
+      paddle::framework::TensorCopy(in_tensors[j], t_out_p, *dev_ctx, &sub_out);
       s = e;
     }
   });

paddle/fluid/framework/details/reduce_op_handle.cc

Lines changed: 13 additions & 11 deletions
@@ -53,6 +53,7 @@ void ReduceOpHandle::RunImpl() {
   // Wait input done, this Wait is asynchronous operation
   WaitInputVarGenerated(in_var_handles);

+  // NOTE: The Places of all input tensor must be all on CPU or all on GPU.
   std::vector<platform::Place> in_places;  // used to get dev_ctx
   for (auto *in_handle : in_var_handles) {
     in_places.emplace_back(in_handle->place_);

@@ -66,22 +67,23 @@
       var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_);
   PADDLE_ENFORCE_NOT_NULL(out_var);

-  // TODO(zcd): The Place of var_handle is determined at building SSA graph
-  // stage, while the Place of var is determined at runtime. If they are
-  // different, DataTransform should be applied. Currently, it has not been done
-  // yet.
-  PADDLE_ENFORCE_EQ(
-      VariableVisitor::GetMutableTensor(pre_in_var).place().which(),
-      out_var_handle->place_.which(),
-      "Currently, Places of input and output must be all on CPU or all on "
-      "GPU.");
+  // NOTE: The tensors' Place of input and output must be all on GPU or all on
+  // CPU.
+  auto in_p = VariableVisitor::GetMutableTensor(pre_in_var).place();
+  platform::Place t_out_p;
+  if (platform::is_gpu_place(in_p)) {
+    PADDLE_ENFORCE(platform::is_gpu_place(out_var_handle->place_),
+                   "Places of input and output must be all on GPU.");
+    t_out_p = out_var_handle->place_;
+  } else {
+    t_out_p = platform::CPUPlace();
+  }

   if (pre_in_var->IsType<framework::SelectedRows>()) {
     std::vector<const SelectedRows *> in_selected_rows =
         GetInputValues<SelectedRows>(in_var_handles, var_scopes);

-    GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_,
-                       out_var_handle->place_,
+    GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
                        out_var->GetMutable<framework::SelectedRows>());
   } else {
     std::vector<const LoDTensor *> lod_tensors =

paddle/fluid/framework/details/ssa_graph_builder.h

Lines changed: 0 additions & 4 deletions
@@ -48,10 +48,6 @@ class SSAGraphBuilder {
                                 const platform::Place &place,
                                 size_t place_offset);

-  static VarHandle *GetLatestVarHandle(SSAGraph *graph,
-                                       const std::string &each_var_name,
-                                       size_t place_offset);
-
   // Add an output variable (each_var_name, place, place_offset) to op_handle,
   // which belongs to graph
   static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,

paddle/fluid/framework/details/var_handle.h

Lines changed: 1 addition & 10 deletions
@@ -62,19 +62,10 @@ struct VarHandle : public VarHandleBase {
   std::string name_;
   platform::Place place_;

-  // NOTE(zcd): Strictly speaking, if the two var_handle is equal, the four
-  // member variables(version_, scope_id_, name_, place_) must be equal. But
-  // sometimes judging whether the two var_handle is equal is actually to
-  // determine whether the two Variables that represented by var_handle is the
-  // same. And the same Variable may have many different var_handles, the
-  // version_ of these var_handles is different. So I don't take care of
-  // version_ temporarily when overloading equal.
-  bool operator==(const VarHandle& o) const {
+  bool IsTheSameVar(const VarHandle& o) const {
     return o.generated_op_ == generated_op_ && o.name_ == name_ &&
            o.scope_idx_ == scope_idx_;
   }
-
-  bool operator!=(const VarHandle& o) const { return !this->operator==(o); }
 };

 // Dummy Variable. It is used to represent dependencies between operators
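
The removed NOTE explained that operator== was never strict field-by-field equality: two var handles can describe the same underlying Variable while carrying different version_ values, so the comparison deliberately looks only at generated_op_, name_, and scope_idx_. Renaming it to IsTheSameVar makes that intent visible at every call site. Below is a simplified, self-contained sketch of that behaviour; the fields and types are mocked for illustration and are not the real VarHandle/VarHandleBase.

#include <cassert>
#include <string>

struct OpHandleBase {};  // stand-in for the op that produced the variable

struct VarHandle {
  size_t version_;
  size_t scope_idx_;
  std::string name_;
  OpHandleBase* generated_op_;

  // "Same variable" deliberately ignores version_ (and place_).
  bool IsTheSameVar(const VarHandle& o) const {
    return o.generated_op_ == generated_op_ && o.name_ == name_ &&
           o.scope_idx_ == scope_idx_;
  }
};

int main() {
  OpHandleBase op;
  VarHandle a{/*version_=*/0, /*scope_idx_=*/1, "w", &op};
  VarHandle b{/*version_=*/3, /*scope_idx_=*/1, "w", &op};
  assert(a.IsTheSameVar(b));  // same Variable even though the versions differ
  return 0;
}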

paddle/fluid/framework/details/variable_visitor.cc

Lines changed: 2 additions & 2 deletions
@@ -88,7 +88,7 @@ void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
   VisitVariable(src, &visitor);
 }

-struct EnforceEqualShapeAndDTypeVisitor {
+struct EnforceShapeAndDTypeEQVisitor {
   const Variable* trg_;

   void operator()(const LoDTensor& src) {

@@ -130,7 +130,7 @@ struct EnforceEqualShapeAndDTypeVisitor {

 void VariableVisitor::EnforceShapeAndDTypeEQ(const Variable& var1,
                                              const Variable& var2) {
-  EnforceEqualShapeAndDTypeVisitor visitor{&var1};
+  EnforceShapeAndDTypeEQVisitor visitor{&var1};
   VisitVariable(var2, &visitor);
 }
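
The last change only renames the private visitor so it matches the public entry point VariableVisitor::EnforceShapeAndDTypeEQ that instantiates it. For readers unfamiliar with the pattern, here is a mocked sketch of the dispatch shape used in variable_visitor.cc: the visitor holds the target variable and exposes one operator() overload per payload type, and a dispatcher calls the overload matching the variable's runtime type. The sketch uses std::variant/std::visit instead of Paddle's VisitVariable, and the field values are invented, so it only illustrates the idea.

#include <iostream>
#include <stdexcept>
#include <variant>

struct LoDTensor { int dims; };        // mock payload types
struct SelectedRows { int height; };
using Variable = std::variant<LoDTensor, SelectedRows>;  // stand-in

struct EnforceShapeAndDTypeEQVisitor {
  const Variable* trg_;
  void operator()(const LoDTensor& src) const {
    if (src.dims != std::get<LoDTensor>(*trg_).dims)
      throw std::runtime_error("The dims of the two Variables are not equal.");
  }
  void operator()(const SelectedRows& src) const {
    if (src.height != std::get<SelectedRows>(*trg_).height)
      throw std::runtime_error("The heights of the two Variables are not equal.");
  }
};

int main() {
  Variable a = LoDTensor{4}, b = LoDTensor{4};
  std::visit(EnforceShapeAndDTypeEQVisitor{&a}, b);  // passes: shapes match
  std::cout << "shapes match\n";
  return 0;
}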
