Skip to content

Commit 3498434

Browse files
committed
fix ci
1 parent 31f598f commit 3498434

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

paddle/operators/concat_op.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ class ConcatKernel : public framework::OpKernel<T> {
3737
size_t output_offset = 0;
3838
for (auto* in : ins) {
3939
auto in_stride = framework::stride_numel(in->dims());
40-
StridedNumelCopyWithAxis<T>(ctx, axis, out->data<T>() + output_offset,
41-
out_stride, in->data<T>(), in_stride);
40+
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
41+
out->data<T>() + output_offset, out_stride,
42+
in->data<T>(), in_stride);
4243
output_offset += in_stride[axis];
4344
}
4445
}
@@ -57,8 +58,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
5758
for (auto& out : outs) {
5859
out->mutable_data<T>(ctx.GetPlace());
5960
auto out_stride = framework::stride_numel(out->dims());
60-
StridedNumelCopyWithAxis<T>(ctx, axis, out->data<T>(), out_stride,
61-
in->data<T>() + input_offset, in_stride);
61+
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
62+
out_stride, in->data<T>() + input_offset,
63+
in_stride);
6264
input_offset += out_stride[axis];
6365
}
6466
}

paddle/operators/split_op.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ class SplitOpKernel : public framework::OpKernel<T> {
3737
for (auto& out : outs) {
3838
out->mutable_data<T>(ctx.GetPlace());
3939
auto out_stride = framework::stride_numel(out->dims());
40-
StridedNumelCopyWithAxis<T>(ctx, axis, out->data<T>(), out_stride,
41-
in->data<T>() + input_offset, in_stride);
40+
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
41+
out_stride, in->data<T>() + input_offset,
42+
in_stride);
4243
input_offset += out_stride[axis];
4344
}
4445
}

paddle/operators/strided_memcpy.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src,
5050
// NOTE: The src and dst tensor should have the same elements
5151
// except the specified axis.
5252
template <typename T>
53-
inline void StridedNumelCopyWithAxis(const framework::ExecutionContext& ctx,
53+
inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
5454
int64_t axis, T* dst,
5555
const framework::DDim& dst_stride_numel,
5656
const T* src,
@@ -88,7 +88,7 @@ inline void StridedNumelCopyWithAxis(const framework::ExecutionContext& ctx,
8888
auto& gpu_place = boost::get<platform::CUDAPlace>(place);
8989
auto& cuda_ctx =
9090
reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
91-
memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
91+
memory::Copy(gpu_place, dst + i * dst_after, gpu_place,
9292
src + i * src_after, sizeof(T) * src_after,
9393
cuda_ctx.stream());
9494
#else

0 commit comments

Comments
 (0)