Skip to content

Commit e5fe893

Browse files
author
Yancey
authored
send_recv variables (#7161)
* send_recv variable * delete unused logs * fix ci failed * update * resize tensor before tensor copy * add selectedrows unit test * check rows
1 parent 7508d52 commit e5fe893

16 files changed

+249
-118
lines changed

paddle/framework/lod_tensor.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,10 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
217217
SerializeToStream(os, static_cast<Tensor>(tensor), dev_ctx);
218218
}
219219

220-
void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
220+
void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
221+
const platform::DeviceContext &dev_ctx) {
221222
{
222-
// the 1st field, unit32_t version for SelectedRows
223+
// the 1st field, unit32_t version for LoDTensor
223224
uint32_t version;
224225
is.read(reinterpret_cast<char *>(&version), sizeof(version));
225226
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
@@ -240,7 +241,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
240241
}
241242
}
242243
// the 3st filed, Tensor
243-
DeserializeFromStream(is, static_cast<Tensor *>(tensor));
244+
DeserializeFromStream(is, static_cast<Tensor *>(tensor), dev_ctx);
244245
}
245246

246247
} // namespace framework

paddle/framework/lod_tensor.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,8 @@ void AppendLoD(LoD* lod, const LoD& lod_length);
208208
*/
209209
void SerializeToStream(std::ostream& os, const LoDTensor& tensor,
210210
const platform::DeviceContext& dev_ctx);
211-
void DeserializeFromStream(std::istream& is, LoDTensor* tensor);
211+
void DeserializeFromStream(std::istream& is, LoDTensor* tensor,
212+
const platform::DeviceContext& dev_ctx);
212213

213214
} // namespace framework
214215
} // namespace paddle

paddle/framework/lod_tensor_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ TEST_F(LoDTensorTester, SerializeAndDeserialize) {
132132
std::ostringstream oss;
133133
SerializeToStream(oss, lod_tensor_, cpu_ctx);
134134
std::istringstream iss(oss.str());
135-
DeserializeFromStream(iss, &dst_tensor);
135+
DeserializeFromStream(iss, &dst_tensor, cpu_ctx);
136136
float* dst_ptr = dst_tensor.mutable_data<float>(platform::CPUPlace());
137137
for (int i = 0; i < kLodTensorSize; ++i) {
138138
EXPECT_EQ(dst_ptr[i], i);

paddle/framework/selected_rows.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
3737
SerializeToStream(os, selected_rows.value(), dev_ctx);
3838
}
3939

40-
void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows) {
41-
auto tensor = *selected_rows->mutable_value();
40+
void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows,
41+
const platform::DeviceContext& dev_ctx) {
4242
{
4343
// the 1st field, unit32_t version for SelectedRows
4444
uint32_t version;
@@ -62,7 +62,7 @@ void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows) {
6262
selected_rows->set_height(height);
6363
}
6464
// the 4st field, tensor which contains the data
65-
DeserializeFromStream(is, &tensor);
65+
DeserializeFromStream(is, selected_rows->mutable_value(), dev_ctx);
6666
}
6767

6868
} // namespace framework

paddle/framework/selected_rows.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ class SelectedRows {
6666
*/
6767
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
6868
const platform::DeviceContext& dev_ctx);
69-
void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows);
69+
void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows,
70+
const platform::DeviceContext& dev_ctx);
7071

7172
} // namespace framework
7273
} // namespace paddle

paddle/framework/selected_rows_test.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,12 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) {
5151
SerializeToStream(oss, *selected_rows_, cpu_ctx);
5252

5353
std::istringstream iss(oss.str());
54-
DeserializeFromStream(iss, &dst_tensor);
54+
DeserializeFromStream(iss, &dst_tensor, cpu_ctx);
5555

5656
ASSERT_EQ(selected_rows_->rows(), dst_tensor.rows());
5757
ASSERT_EQ(selected_rows_->height(), dst_tensor.height());
58+
ASSERT_EQ(selected_rows_->value().dims(), dst_tensor.value().dims());
59+
ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims());
5860
}
5961

6062
} // namespace framework

paddle/framework/tensor_util.h

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,23 @@ inline void SerializeToStream(std::ostream& os, const Tensor& tensor,
270270
}
271271
}
272272

273-
inline void DeserializeFromStream(std::istream& is, Tensor* tensor) {
273+
struct DeserializedDataFunctor {
274+
DeserializedDataFunctor(void** buf, Tensor* tensor,
275+
const platform::Place& place)
276+
: buf_(buf), tensor_(tensor), place_(place) {}
277+
278+
template <typename T>
279+
void operator()() {
280+
*buf_ = tensor_->mutable_data<T>(place_);
281+
}
282+
283+
void** buf_;
284+
Tensor* tensor_;
285+
platform::Place place_;
286+
};
287+
288+
inline void DeserializeFromStream(std::istream& is, Tensor* tensor,
289+
const platform::DeviceContext& dev_ctx) {
274290
uint32_t version;
275291
is.read(reinterpret_cast<char*>(&version), sizeof(version));
276292
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
@@ -289,27 +305,28 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor) {
289305
dims.reserve(static_cast<size_t>(desc.dims().size()));
290306
std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
291307
tensor->Resize(framework::make_ddim(dims));
292-
293308
void* buf;
294-
platform::Place cpu = platform::CPUPlace();
295-
// TODO(Yancey1989): use VisiterDataType instead of DataType switch
296-
switch (desc.data_type()) {
297-
case proto::FP32:
298-
buf = tensor->mutable_data<float>(cpu);
299-
break;
300-
case proto::FP64:
301-
buf = tensor->mutable_data<double>(cpu);
302-
break;
303-
case proto::INT32:
304-
buf = tensor->mutable_data<int>(cpu);
305-
break;
306-
case proto::INT64:
307-
buf = tensor->mutable_data<int64_t>(cpu);
308-
break;
309-
default:
310-
PADDLE_THROW("DataType %d not supported", desc.data_type());
309+
auto ctx = platform::CPUDeviceContext();
310+
if (platform::is_gpu_place(dev_ctx.GetPlace())) {
311+
#ifdef PADDLE_WITH_CUDA
312+
Tensor cpu_tensor;
313+
cpu_tensor.Resize(framework::make_ddim(dims));
314+
framework::VisitDataType(
315+
desc.data_type(),
316+
DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace()));
317+
is.read(static_cast<char*>(buf), cpu_tensor.memory_size());
318+
auto cpu_place = new platform::CPUPlace();
319+
framework::CopyFrom(cpu_tensor, *cpu_place, dev_ctx, tensor);
320+
delete cpu_place;
321+
#else
322+
PADDLE_THROW("Unexpected branch");
323+
#endif
324+
} else {
325+
framework::VisitDataType(
326+
desc.data_type(),
327+
DeserializedDataFunctor(&buf, tensor, ctx.GetPlace()));
328+
is.read(static_cast<char*>(buf), tensor->memory_size());
311329
}
312-
is.read(static_cast<char*>(buf), tensor->memory_size());
313330
}
314331
}
315332

paddle/framework/tensor_util_test.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,11 +270,12 @@ TEST(Tensor, SerializeAndDeserialize) {
270270
SerializeToStream(oss, src_tensor, cpu_ctx);
271271

272272
std::istringstream iss(oss.str());
273-
DeserializeFromStream(iss, &dst_tensor);
273+
DeserializeFromStream(iss, &dst_tensor, cpu_ctx);
274274
int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace());
275275
for (int i = 0; i < 5; ++i) {
276276
ASSERT_EQ(dst_ptr[i], array[i]);
277277
}
278+
ASSERT_EQ(dst_tensor.dims(), src_tensor.dims());
278279
delete place;
279280
}
280281
#ifdef PADDLE_WITH_CUDA
@@ -292,13 +293,12 @@ TEST(Tensor, SerializeAndDeserialize) {
292293
SerializeToStream(oss, gpu_tensor, gpu_ctx);
293294

294295
std::istringstream iss(oss.str());
295-
DeserializeFromStream(iss, &dst_tensor);
296+
DeserializeFromStream(iss, &dst_tensor, gpu_ctx);
296297

297298
int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace());
298299
for (int i = 0; i < 6; ++i) {
299300
ASSERT_EQ(dst_ptr[i], array[i]);
300301
}
301-
302302
delete gpu_place;
303303
}
304304
#endif

paddle/framework/var_type.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ limitations under the License. */
1717
#include "paddle/framework/lod_rank_table.h"
1818
#include "paddle/framework/lod_tensor.h"
1919
#include "paddle/framework/lod_tensor_array.h"
20+
#include "paddle/framework/selected_rows.h"
21+
#include "paddle/framework/variable.h"
2022

2123
namespace paddle {
2224
namespace framework {
@@ -35,7 +37,7 @@ inline proto::VarDesc::VarType ToVarType(std::type_index type) {
3537
}
3638

3739
template <typename Visitor>
38-
inline void VisitVarType(const Variable& var, Visitor visitor) {
40+
inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
3941
switch (ToVarType(var.Type())) {
4042
case proto::VarDesc_VarType_LOD_TENSOR:
4143
visitor(var.Get<framework::LoDTensor>());

paddle/operators/detail/recv_impl.cc

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,9 @@ namespace detail {
2121
Status SendRecvServerImpl::SendVariable(ServerContext *context,
2222
const VariableMessage *in_var,
2323
VoidMessage *out_var) {
24-
// TODO(typhoonzero): support different variable types.
25-
std::istringstream iss(in_var->serialized());
26-
framework::LoDTensor t;
27-
framework::DeserializeFromStream(iss, &t);
28-
TensorWithName tensor_with_name =
29-
std::make_pair(in_var->varname(), std::move(t));
30-
31-
var_recv_queue_.Push(std::move(tensor_with_name));
24+
MessageWithName msg_with_name =
25+
std::make_pair(in_var->varname(), std::move(*in_var));
26+
var_recv_queue_.Push(std::move(msg_with_name));
3227
return Status::OK;
3328
}
3429

@@ -37,14 +32,8 @@ Status SendRecvServerImpl::GetVariable(ServerContext *context,
3732
VariableMessage *out_var) {
3833
std::string get_var_name = in_var->varname();
3934
auto *var = scope_->FindVar(get_var_name);
40-
auto tensor = var->Get<framework::LoDTensor>();
41-
std::ostringstream oss;
42-
framework::SerializeToStream(oss, tensor, platform::CPUDeviceContext());
4335

44-
std::string *varname = out_var->mutable_varname();
45-
*varname = get_var_name;
46-
std::string *serialized = out_var->mutable_serialized();
47-
*serialized = oss.str();
36+
SerializeToMessage(get_var_name, var, platform::CPUDeviceContext(), out_var);
4837
return Status::OK;
4938
}
5039

0 commit comments

Comments
 (0)