14
14
15
15
#pragma once
16
16
17
+ #include < vector>
17
18
#include " paddle/fluid/framework/op_registry.h"
18
19
#include " paddle/fluid/operators/detail/safe_ref.h"
19
20
#include " paddle/fluid/platform/device_context.h"
20
21
#include " paddle/fluid/platform/for_range.h"
21
- #include " thrust/random.h"
22
+ #ifdef PADDLE_WITH_CUDA
23
+ #include < thrust/random.h>
24
+ #endif
22
25
23
26
namespace paddle {
24
27
namespace operators {
@@ -34,36 +37,39 @@ struct Random<platform::CPUDeviceContext> {
34
37
using UniformIntDist = std::uniform_int_distribution<T>;
35
38
};
36
39
40
#ifdef PADDLE_WITH_CUDA
// Random-engine traits for CUDA execution. Thrust's engines and
// distributions are callable from device code, unlike <random>, so the
// CUDA specialization maps onto their thrust equivalents.
template <>
struct Random<platform::CUDADeviceContext> {
  // Minimal standard linear congruential engine, device-callable.
  using Engine = thrust::minstd_rand;

  // Uniform integer distribution over a closed range, device-callable.
  template <typename T>
  using UniformIntDist = thrust::uniform_int_distribution<T>;
};
#endif
44
49
45
50
template <typename T>
46
- HOSTDEVICE inline void RandomCropImpl (const T* x, size_t * x_dim, T* out,
47
- size_t * out_dim, int i, int rank,
48
- int64_t prod_x_remain,
49
- int64_t prod_out_remain, size_t * offset) {
50
- size_t x_length = x_dim[rank];
51
- size_t out_length = out_dim[rank];
52
-
53
- int64_t x_stride = prod_x_remain / x_length;
54
- int64_t out_stride = prod_out_remain / out_length;
55
- size_t offset_i = offset[i];
56
- if (x_stride == 1 && out_stride == 1 ) {
57
- // In the final stage, copy from offset.
51
+ HOSTDEVICE inline void StridedMemcpy (const T* x, const size_t * x_dims, T* out,
52
+ const size_t * out_dims, int i, int rank,
53
+ size_t prod_x_remain,
54
+ size_t prod_out_remain,
55
+ const size_t * offsets) {
56
+ size_t x_dim_i = x_dims[i];
57
+ size_t out_dim_i = out_dims[i];
58
+ size_t x_stride = prod_x_remain / x_dim_i;
59
+ size_t out_stride = prod_out_remain / out_dim_i;
60
+ size_t offset_i = offsets[i];
61
+
62
+ if (i == rank - 1 ) {
63
+ PADDLE_ENFORCE (x_stride == 1 && out_stride == 1 );
58
64
x += offset_i;
59
- for (size_t i = 0 ; i < out_length ; ++i ) {
65
+ for (size_t j = 0 ; j < out_dim_i ; ++j ) {
60
66
*out++ = *x++;
61
67
}
62
68
} else {
63
69
x += offset_i * x_stride;
64
- for (size_t i = 0 ; i < out_length ; ++i ) {
65
- RandomCropImpl <T>(x, x_dim , out, out_dim , i + 1 , rank, x_stride,
66
- out_stride, offset );
70
+ for (size_t j = 0 ; j < x_dim_i ; ++j ) {
71
+ StridedMemcpy <T>(x, x_dims , out, out_dims , i + 1 , rank, x_stride,
72
+ out_stride, offsets );
67
73
x += x_stride;
68
74
out += out_stride;
69
75
}
@@ -74,94 +80,96 @@ template <typename DeviceContext, typename T>
74
80
struct RandomCropFunctor {
75
81
const T* x_;
76
82
T* out_;
77
- size_t x_dim_[9 ];
78
- size_t out_dim_[9 ];
79
- size_t prod_same_dim_;
80
-
81
- size_t prod_x_dim_;
82
- size_t prod_out_dim_;
83
-
84
- int num_same_dim_;
83
+ size_t x_dims_[9 ];
84
+ size_t out_dims_[9 ];
85
+ int num_batchsize_dims_;
85
86
int rank_;
86
-
87
87
int64_t seed_;
88
88
89
- RandomCropFunctor (const T* x, T* out, int64_t seed)
89
+ size_t prod_x_dims_;
90
+ size_t prod_out_dims_;
91
+ size_t prod_batchsize_dims_;
92
+ size_t prod_x_ins_dims_;
93
+ size_t prod_out_ins_dims_;
94
+
95
+ RandomCropFunctor (const T* x, T* out, const framework::DDim& x_dims,
96
+ const framework::DDim& out_dims, int num_batchsize_dims,
97
+ int64_t seed)
90
98
: x_(x),
91
99
out_ (out),
92
- prod_same_dim_(1 ),
93
- prod_x_dim_(1 ),
94
- prod_out_dim_(1 ),
100
+ num_batchsize_dims_(num_batchsize_dims),
101
+ rank_(x_dims.size()),
95
102
seed_(seed) {
96
- std::fill (x_dim_, x_dim_ + sizeof (x_dim_) / sizeof (size_t ), 0 );
97
- std::fill (out_dim_, out_dim_ + sizeof (out_dim_) / sizeof (size_t ), 0 );
103
+ PADDLE_ENFORCE_EQ (x_dims.size (), out_dims.size ());
104
+ PADDLE_ENFORCE_GT (rank_, num_batchsize_dims_);
105
+ prod_batchsize_dims_ = 1 ;
106
+ prod_x_ins_dims_ = 1 ;
107
+ prod_out_ins_dims_ = 1 ;
108
+ for (size_t i = 0 ; i < rank_; ++i) {
109
+ size_t x_dim_i = x_dims[i];
110
+ size_t out_dim_i = out_dims[i];
111
+ x_dims_[i] = x_dim_i;
112
+ out_dims_[i] = out_dim_i;
113
+ if (i < num_batchsize_dims_) {
114
+ PADDLE_ENFORCE_EQ (x_dim_i, out_dim_i);
115
+ prod_batchsize_dims_ *= x_dim_i;
116
+ } else {
117
+ prod_x_ins_dims_ *= x_dim_i;
118
+ prod_out_ins_dims_ *= out_dim_i;
119
+ }
120
+ }
121
+ prod_x_dims_ = prod_batchsize_dims_ * prod_x_ins_dims_;
122
+ prod_out_dims_ = prod_batchsize_dims_ * prod_out_ins_dims_;
98
123
}
99
124
100
- HOSTDEVICE void operator ()(size_t i ) {
125
+ HOSTDEVICE void operator ()(size_t ins_idx ) {
101
126
typename Random<DeviceContext>::Engine engine (seed_);
102
- engine.discard (i * (rank_ - num_same_dim_));
103
-
104
- int64_t prod_x_unsame = (prod_x_dim_ / prod_same_dim_);
105
- int64_t prod_out_unsame = (prod_out_dim_ / prod_same_dim_);
106
-
107
- const T* x = x_ + i * prod_x_unsame;
108
- T* out = out_ + i * prod_out_unsame;
109
-
110
- size_t offset[9 ];
111
- for (int i = num_same_dim_; i < rank_; ++i) {
127
+ engine.discard (ins_idx * (rank_ - num_batchsize_dims_));
128
+ size_t offsets[9 ];
129
+ for (int i = num_batchsize_dims_; i < rank_; ++i) {
112
130
typename Random<DeviceContext>::template UniformIntDist<size_t > dist (
113
- 0 , x_dim_ [i] - out_dim_ [i]);
114
- offset [i] = dist (engine);
131
+ 0 , x_dims_ [i] - out_dims_ [i]);
132
+ offsets [i] = dist (engine);
115
133
}
116
- RandomCropImpl<T>(x, x_dim_, out, out_dim_, num_same_dim_, rank_,
117
- prod_x_unsame, prod_out_unsame, offset);
134
+
135
+ const T* x = x_ + ins_idx * prod_x_ins_dims_;
136
+ T* out = out_ + ins_idx * prod_out_ins_dims_;
137
+
138
+ StridedMemcpy<T>(x, x_dims_ + num_batchsize_dims_, out,
139
+ out_dims_ + num_batchsize_dims_, 0 ,
140
+ rank_ - num_batchsize_dims_, prod_x_ins_dims_,
141
+ prod_out_ins_dims_, offsets);
118
142
}
119
143
};
120
144
121
145
template <typename DeviceContext, typename T>
122
146
class RandomCropKernel : public framework ::OpKernel<T> {
123
147
public:
124
- virtual void Compute (const framework::ExecutionContext& context) const {
125
- int64_t seed =
126
- *context.Input <framework::LoDTensor>(" Seed" )->data <int64_t >();
127
- auto & x = detail::Ref (context.Input <framework::LoDTensor>(" X" ));
128
- auto & out = detail::Ref (context.Output <framework::LoDTensor>(" Out" ));
129
-
130
- RandomCropFunctor<DeviceContext, T> functor{
131
- x.data <T>(), out.mutable_data <T>(context.GetPlace ()), seed};
132
-
133
- auto & out_dim = out.dims ();
134
- auto & x_dim = x.dims ();
135
-
136
- auto rank = x_dim.size ();
137
- while (rank-- > 0 ) {
138
- functor.x_dim_ [rank] = x_dim[rank];
139
- functor.out_dim_ [rank] = out_dim[rank];
140
- functor.prod_x_dim_ *= x_dim[rank];
141
- functor.prod_out_dim_ *= out_dim[rank];
142
- if (x_dim[rank] != out_dim[rank]) {
143
- PADDLE_ENFORCE_EQ (functor.prod_same_dim_ , 1 );
144
- functor.num_same_dim_ = rank;
145
- } else {
146
- functor.prod_same_dim_ *= out_dim[rank];
147
- }
148
- }
149
- functor.rank_ = x_dim.size ();
150
-
148
+ virtual void Compute (const framework::ExecutionContext& ctx) const {
149
+ int64_t seed = *ctx.Input <framework::LoDTensor>(" Seed" )->data <int64_t >();
150
+ auto shape = ctx.Attr <std::vector<int >>(" shape" );
151
+ auto & x = detail::Ref (ctx.Input <framework::LoDTensor>(" X" ));
152
+ auto & out = detail::Ref (ctx.Output <framework::LoDTensor>(" Out" ));
153
+
154
+ int num_batchsize_dims = x.dims ().size () - shape.size ();
155
+ RandomCropFunctor<DeviceContext, T> functor (
156
+ x.data <T>(), out.mutable_data <T>(ctx.GetPlace ()), x.dims (), out.dims (),
157
+ num_batchsize_dims, seed);
151
158
platform::ForRange<DeviceContext> for_range (
152
- context .template device_context <DeviceContext>(),
153
- functor.prod_same_dim_ );
159
+ ctx .template device_context <DeviceContext>(),
160
+ functor.prod_batchsize_dims_ );
154
161
155
162
for_range (functor);
156
163
157
164
Random<platform::CPUDeviceContext>::Engine engine (seed);
158
- engine.discard (functor.prod_same_dim_ *
159
- (functor.rank_ - functor.num_same_dim_ ));
160
-
161
- *context.Output <framework::LoDTensor>(" SeedOut" )->mutable_data <int64_t >(
165
+ engine.discard (functor.prod_batchsize_dims_ *
166
+ (functor.rank_ - functor.num_batchsize_dims_ ));
167
+ *ctx.Output <framework::LoDTensor>(" SeedOut" )->mutable_data <int64_t >(
162
168
platform::CPUPlace ()) = engine ();
163
169
}
164
170
};
165
171
172
+ // TODO(fengjiayi): Backward of random crop op
173
+
166
174
} // namespace operators
167
175
} // namespace paddle
0 commit comments