@@ -58,6 +58,7 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
   int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
   int64_t src_after = src_stride_numel[axis];
   int64_t dst_after = dst_stride_numel[axis];
+  int64_t copy_size = std::min(src_after, dst_after);
   auto place = ctx.GetPlace();
 
   PADDLE_ENFORCE_EQ(src_stride_numel.size(), dst_stride_numel.size(),
@@ -82,14 +83,14 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
     if (platform::is_cpu_place(place)) {
       auto& cpu_place = boost::get<platform::CPUPlace>(place);
       memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
-                   src + i * src_after, sizeof(T) * src_after);
+                   src + i * src_after, sizeof(T) * copy_size);
     } else {
 #ifdef PADDLE_WITH_CUDA
       auto& gpu_place = boost::get<platform::CUDAPlace>(place);
       auto& cuda_ctx =
           reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
       memory::Copy(gpu_place, dst + i * dst_after, gpu_place,
-                   src + i * src_after, sizeof(T) * src_after,
+                   src + i * src_after, sizeof(T) * copy_size,
                    cuda_ctx.stream());
 #else
       PADDLE_THROW("Paddle is not compiled with GPU");
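
For reference, a minimal standalone sketch of the bug this patch fixes: when the source and destination rows have different widths along the copy axis, copying `sizeof(T) * src_after` bytes per row can run past the end of the narrower buffer, so the byte count is clamped to the smaller of the two widths. The helper `StridedRowCopy` and its parameter names below are illustrative assumptions, not part of Paddle's API:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Copies `rows` rows from a buffer whose rows hold `src_width` elements
// into a buffer whose rows hold `dst_width` elements. Clamping the
// per-row copy size to the smaller width (as the patch above does with
// copy_size) keeps every memcpy within both buffers.
template <typename T>
void StridedRowCopy(T* dst, int64_t dst_width, const T* src,
                    int64_t src_width, int64_t rows) {
  int64_t copy_size = std::min(src_width, dst_width);
  for (int64_t i = 0; i < rows; ++i) {
    std::memcpy(dst + i * dst_width, src + i * src_width,
                sizeof(T) * copy_size);
  }
}

int main() {
  std::vector<float> src(3 * 4, 1.f);  // 3 rows, 4 elements per row
  std::vector<float> dst(3 * 2, 0.f);  // 3 rows, 2 elements per row
  // Copying sizeof(float) * 4 per row would overflow dst; clamping to
  // 2 elements per row stays in bounds.
  StridedRowCopy(dst.data(), 2, src.data(), 4, 3);
  for (float v : dst) std::cout << v << ' ';  // prints: 1 1 1 1 1 1
  std::cout << '\n';
}
```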