@@ -15,6 +15,7 @@
 #include <algorithm>
 #include <limits>
 #include <vector>
+#include "paddle/fluid/framework/data_type.h"
 
 namespace paddle {
 namespace framework {
@@ -261,7 +262,8 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
     os.write(out.data(), size);
   }
   {  // the 3rd field, tensor data
-    uint64_t size = tensor.memory_size();
+    uint64_t size = tensor.numel() * framework::SizeOfType(tensor.type());
+
     auto* data_ptr = tensor.data<void>();
     PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
                    "Index overflow when writing tensor");
@@ -331,14 +333,17 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
     tensor->Resize(framework::make_ddim(dims));
     void* buf;
     auto ctx = platform::CPUDeviceContext();
+    size_t size =
+        tensor->numel() *
+        framework::SizeOfType(framework::ToTypeIndex(desc.data_type()));
     if (platform::is_gpu_place(dev_ctx.GetPlace())) {
 #ifdef PADDLE_WITH_CUDA
       Tensor cpu_tensor;
       cpu_tensor.Resize(framework::make_ddim(dims));
       framework::VisitDataType(
           desc.data_type(),
           DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace()));
-      is.read(static_cast<char*>(buf), cpu_tensor.memory_size());
+      is.read(static_cast<char*>(buf), size);
       auto dst_place = dev_ctx.GetPlace();
       framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor);
 #else
@@ -348,7 +353,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
       framework::VisitDataType(
           desc.data_type(),
           DeserializedDataFunctor(&buf, tensor, ctx.GetPlace()));
-      is.read(static_cast<char*>(buf), tensor->memory_size());
+      is.read(static_cast<char*>(buf), size);
     }
   }
 }
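The change above replaces `memory_size()` with a size computed from the tensor's shape, so serialization writes, and deserialization reads back, exactly `numel() * sizeof(element)` bytes even when the backing allocation is larger. A minimal stand-alone sketch of that distinction, using a hypothetical `ToyTensor` rather than Paddle's `Tensor` (all names below are illustrative, not Paddle APIs):

```cpp
// Sketch: serialize only the bytes covered by the logical shape, not the
// whole allocation. ToyTensor is a made-up type for illustration.
#include <cstdint>
#include <iostream>
#include <sstream>
#include <vector>

struct ToyTensor {
  std::vector<int64_t> dims;   // logical shape
  std::vector<float> buffer;   // backing storage, may be over-allocated

  int64_t numel() const {
    int64_t n = 1;
    for (int64_t d : dims) n *= d;
    return n;
  }
  // Bytes described by the shape (analogous to numel() * SizeOfType(type)).
  size_t data_bytes() const { return numel() * sizeof(float); }
  // Bytes held by the allocation (analogous to an allocation-based size).
  size_t allocation_bytes() const { return buffer.size() * sizeof(float); }
};

void Serialize(std::ostream& os, const ToyTensor& t) {
  // Write only the logically valid bytes, mirroring the patched code path.
  os.write(reinterpret_cast<const char*>(t.buffer.data()),
           static_cast<std::streamsize>(t.data_bytes()));
}

int main() {
  ToyTensor t;
  t.dims = {2, 3};           // 6 elements are logically valid
  t.buffer.assign(8, 1.0f);  // but the buffer has room for 8

  std::ostringstream os;
  Serialize(os, t);
  std::cout << "wrote " << os.str().size() << " bytes, allocation holds "
            << t.allocation_bytes() << " bytes\n";  // 24 vs 32
  return 0;
}
```

Writing 24 bytes here instead of the 32-byte allocation keeps the on-disk size tied to the logical shape, which is also what `TensorFromStream` reconstructs from the dims and `desc.data_type()` when it computes `size` before the `is.read` calls.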