Skip to content

Commit 9a05c90

Browse files
committed
fix StridedNumelCopyWithAxis
1 parent 01f4bcb commit 9a05c90

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

paddle/fluid/operators/strided_memcpy.h

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -58,6 +58,7 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
58 58
int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
59 59
int64_t src_after = src_stride_numel[axis];
60 60
int64_t dst_after = dst_stride_numel[axis];
61+
int64_t copy_size = std::min(src_after, dst_after);
61 62
auto place = ctx.GetPlace();
62 63

63 64
PADDLE_ENFORCE_EQ(src_stride_numel.size(), dst_stride_numel.size(),
@@ -82,14 +83,14 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
82 83
if (platform::is_cpu_place(place)) {
83 84
auto& cpu_place = boost::get<platform::CPUPlace>(place);
84 85
memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
85-
src + i * src_after, sizeof(T) * src_after);
86+
src + i * src_after, sizeof(T) * copy_size);
86 87
} else {
87 88
#ifdef PADDLE_WITH_CUDA
88 89
auto& gpu_place = boost::get<platform::CUDAPlace>(place);
89 90
auto& cuda_ctx =
90 91
reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
91 92
memory::Copy(gpu_place, dst + i * dst_after, gpu_place,
92-
src + i * src_after, sizeof(T) * src_after,
93+
src + i * src_after, sizeof(T) * copy_size,
93 94
cuda_ctx.stream());
94 95
#else
95 96
PADDLE_THROW("Paddle is not compiled with GPU");

0 commit comments

Comments
 (0)