Skip to content

Commit 4760ac4

Browse files
committed
check the generate_op is null or not and add DEPS of broadcast_op_handle and gather_op_handle
1 parent d24ef93 commit 4760ac4

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,10 @@ cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framewor
2121
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
2222
simple_threadpool device_context)
2323

24-
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory)
25-
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory)
26-
2724
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
2825

26+
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base variable_visitor scope ddim memory)
27+
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope variable_visitor ddim memory)
2928

3029
cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
3130
device_context broadcast_op_handle)

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,9 @@ void BroadcastOpHandle::RunImpl() {
6161
"Places must be all on CPU or all on CUDA.");
6262

6363
VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
64-
VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
65-
in_tensor.type());
64+
VariableVisitor::GetMutableTensor(out_var)
65+
.Resize(in_tensor.dims())
66+
.mutable_data(out_p, in_tensor.type());
6667

6768
auto dev_ctx = dev_ctxes_[out_p];
6869
RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
@@ -74,8 +75,10 @@ void BroadcastOpHandle::RunImpl() {
7475
}
7576

7677
void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
77-
for (auto &pair : dev_ctxes_) {
78-
in_var.generated_op_->Wait(pair.second);
78+
if (in_var.generated_op_) {
79+
for (auto &pair : dev_ctxes_) {
80+
in_var.generated_op_->Wait(pair.second);
81+
}
7982
}
8083
}
8184

0 commit comments

Comments
 (0)