Skip to content

Commit 6c981e7

Browse files
authored
Merge pull request #12259 from reyoung/feature/fix_serialize_deserialize_bug
Fix deserialize bug
2 parents be04fbf + 47ad8d4 commit 6c981e7

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

paddle/fluid/framework/tensor_util.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <algorithm>
1616
#include <limits>
1717
#include <vector>
18+
#include "paddle/fluid/framework/data_type.h"
1819

1920
namespace paddle {
2021
namespace framework {
@@ -261,7 +262,8 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
261262
os.write(out.data(), size);
262263
}
263264
{ // the 3rd field, tensor data
264-
uint64_t size = tensor.memory_size();
265+
uint64_t size = tensor.numel() * framework::SizeOfType(tensor.type());
266+
265267
auto* data_ptr = tensor.data<void>();
266268
PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
267269
"Index overflow when writing tensor");
@@ -331,14 +333,17 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
331333
tensor->Resize(framework::make_ddim(dims));
332334
void* buf;
333335
auto ctx = platform::CPUDeviceContext();
336+
size_t size =
337+
tensor->numel() *
338+
framework::SizeOfType(framework::ToTypeIndex(desc.data_type()));
334339
if (platform::is_gpu_place(dev_ctx.GetPlace())) {
335340
#ifdef PADDLE_WITH_CUDA
336341
Tensor cpu_tensor;
337342
cpu_tensor.Resize(framework::make_ddim(dims));
338343
framework::VisitDataType(
339344
desc.data_type(),
340345
DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace()));
341-
is.read(static_cast<char*>(buf), cpu_tensor.memory_size());
346+
is.read(static_cast<char*>(buf), size);
342347
auto dst_place = dev_ctx.GetPlace();
343348
framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor);
344349
#else
@@ -348,7 +353,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
348353
framework::VisitDataType(
349354
desc.data_type(),
350355
DeserializedDataFunctor(&buf, tensor, ctx.GetPlace()));
351-
is.read(static_cast<char*>(buf), tensor->memory_size());
356+
is.read(static_cast<char*>(buf), size);
352357
}
353358
}
354359
}

0 commit comments

Comments
 (0)