Skip to content

Commit a6c11a5

Browse files
committed
Fix bug in CUDA
1 parent 45530c7 commit a6c11a5

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

paddle/fluid/operators/random_crop_op.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ HOSTDEVICE inline void StridedMemcpy(const T* x, const size_t* x_dims, T* out,
6767
}
6868
} else {
6969
x += offset_i * x_stride;
70-
for (size_t j = 0; j < x_dim_i; ++j) {
70+
for (size_t j = 0; j < out_dim_i; ++j) {
7171
StridedMemcpy<T>(x, x_dims, out, out_dims, i + 1, rank, x_stride,
7272
out_stride, offsets);
7373
x += x_stride;
@@ -86,8 +86,6 @@ struct RandomCropFunctor {
8686
int rank_;
8787
int64_t seed_;
8888

89-
size_t prod_x_dims_;
90-
size_t prod_out_dims_;
9189
size_t prod_batchsize_dims_;
9290
size_t prod_x_ins_dims_;
9391
size_t prod_out_ins_dims_;
@@ -118,8 +116,6 @@ struct RandomCropFunctor {
118116
prod_out_ins_dims_ *= out_dim_i;
119117
}
120118
}
121-
prod_x_dims_ = prod_batchsize_dims_ * prod_x_ins_dims_;
122-
prod_out_dims_ = prod_batchsize_dims_ * prod_out_ins_dims_;
123119
}
124120

125121
HOSTDEVICE void operator()(size_t ins_idx) {
@@ -146,7 +142,17 @@ template <typename DeviceContext, typename T>
146142
class RandomCropKernel : public framework::OpKernel<T> {
147143
public:
148144
virtual void Compute(const framework::ExecutionContext& ctx) const {
149-
int64_t seed = *ctx.Input<framework::LoDTensor>("Seed")->data<int64_t>();
145+
auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed"));
146+
int64_t seed = 0;
147+
if (platform::is_cpu_place(seed_tensor.place())) {
148+
seed = *seed_tensor.data<int64_t>();
149+
} else {
150+
LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
151+
"your program";
152+
framework::LoDTensor cpu_seed;
153+
framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed);
154+
seed = *cpu_seed.data<int64_t>();
155+
}
150156
auto shape = ctx.Attr<std::vector<int>>("shape");
151157
auto& x = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
152158
auto& out = detail::Ref(ctx.Output<framework::LoDTensor>("Out"));

0 commit comments

Comments
 (0)