Skip to content

Commit 13de723

Browse files
committed
Fix broadcast
1 parent 28a86ae; commit 13de723

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,9 @@ void BroadcastOpHandle::RunImpl() {
7373
int root_id = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
7474
std::vector<std::function<void()>> broadcast_calls;
7575

76+
int type = platform::ToNCCLDataType(in_tensor.type());
77+
size_t numel = static_cast<size_t>(in_tensor.numel());
78+
7679
for (auto out_var_handle : out_var_handles) {
7780
Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
7881
->FindVar(out_var_handle->name_);
@@ -87,13 +90,11 @@ void BroadcastOpHandle::RunImpl() {
8790
send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
8891
out_handle = out_var_handle;
8992
} else {
90-
send_recv_buffer =
91-
VariableVisitor::GetMutableTensor(out_var).mutable_data(
92-
out_var_handle->place_);
93+
send_recv_buffer = VariableVisitor::GetMutableTensor(out_var)
94+
.Resize(in_tensor.dims())
95+
.mutable_data(out_var_handle->place_);
9396
}
9497

95-
int type = platform::ToNCCLDataType(in_tensor.type());
96-
size_t numel = static_cast<size_t>(in_tensor.numel());
9798
broadcast_calls.emplace_back(
9899
[send_recv_buffer, numel, type, root_id, &nccl_ctx] {
99100
PADDLE_ENFORCE(platform::dynload::ncclBcast(

0 commit comments

Comments
 (0)