@@ -12,14 +12,108 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

- #include "paddle/fluid/operators/reshape_op.h"
-
#include <string>
#include <vector>
+ #include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

+ class ReshapeOp : public framework::OperatorWithKernel {
+  public:
+   ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs,
+             const framework::VariableNameMap &outputs,
+             const framework::AttributeMap &attrs)
+       : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+   void InferShape(framework::InferShapeContext *ctx) const override {
+     PADDLE_ENFORCE(ctx->HasInput("X"),
+                    "Input(X) of ReshapeOp should not be null.");
+     PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                    "Output(Out) of ReshapeOp should not be null.");
+
+     const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape");
+     PADDLE_ENFORCE(!shape.empty(),
+                    "The shape information must be set by Attr(shape).");
+
+     if (ctx->HasInput("Shape") && ctx->IsRuntime()) {
+       // If true, set the shape of Output(Out) according to Input(Shape) in
+       // ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel.
+       ctx->ShareLoD("X", /*->*/ "Out");
+       return;
+     }
+
+     auto x_dims = ctx->GetInputDim("X");
+     auto out_dims = ValidateShape(shape, x_dims);
+     ctx->SetOutputDim("Out", out_dims);
+     if (x_dims[0] == out_dims[0]) {
+       // Only pass LoD when the first dimension of output and Input(X)
+       // are the same.
+       ctx->ShareLoD("X", /*->*/ "Out");
+     }
+   }
+
+   static framework::DDim ValidateShape(const std::vector<int> shape,
+                                        const framework::DDim &in_dims) {
+     const int64_t in_size = framework::product(in_dims);
+     // Only one dimension may be set to -1; its size will be inferred
+     // automatically.
+     const int64_t unk_dim_val = -1;
+     const int64_t copy_dim_val = 0;
+
+     std::vector<int64_t> output_shape(shape.size(), 0);
+     int64_t capacity = 1;
+     int unk_dim_idx = -1;
+     for (size_t i = 0; i < shape.size(); ++i) {
+       if (shape[i] == unk_dim_val) {
+         PADDLE_ENFORCE(
+             unk_dim_idx == -1,
+             "Only one input dimension of Attr(shape) can be unknown.");
+         unk_dim_idx = i;
+       } else if (shape[i] == copy_dim_val) {
+         PADDLE_ENFORCE(
+             static_cast<int>(i) < in_dims.size(),
+             "The index of dimension to copy from input shape must be less "
+             "than the size of input shape.");
+       } else {
+         PADDLE_ENFORCE(
+             shape[i] > 0,
+             "Each input dimension of Attr(shape) must not be negative except "
+             "one unknown dimension.");
+       }
+
+       capacity *= (shape[i] ? shape[i] : in_dims[i]);
+       output_shape[i] =
+           (shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
+     }
+
+     if (unk_dim_idx != -1) {
+       if (in_size > 0) {
+         // When in_size < 0, the input shape is indeterminate at compile time
+         // and this check is skipped. For example, with in_dims = [-1, 8, 1, 1]
+         // and shape = [-1, 3, 8], capacity = -24, in_size = -8, and
+         // output_shape[0] = 0, so the following check would fail.
+         output_shape[unk_dim_idx] = -in_size / capacity;
+         PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size,
+                           "Invalid shape is given.");
+       } else {
+         output_shape[unk_dim_idx] = -1;
+       }
+     } else {
+       PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given.");
+     }
+     return framework::make_ddim(output_shape);
+   }
+
+  protected:
+   framework::OpKernelType GetExpectedKernelType(
+       const framework::ExecutionContext &ctx) const override {
+     return framework::OpKernelType(
+         framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+         ctx.device_context());
+   }
+ };
+

class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
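For reference, the resolution rule ValidateShape() implements in the hunk above: a 0 entry in Attr(shape) copies the corresponding input dimension, and a single -1 entry is inferred so that the total element count is preserved. Below is a minimal, dependency-free sketch of that rule; the helper name resolve_shape and the plain std::vector types are illustrative only, not part of the Paddle API, and it assumes all input dimensions are known (positive), unlike the compile-time -1 case handled above.

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative stand-in for ReshapeOp::ValidateShape, minus PADDLE_ENFORCE
// and DDim.
std::vector<int64_t> resolve_shape(const std::vector<int> &shape,
                                   const std::vector<int64_t> &in_dims) {
  int64_t in_size = 1;
  for (int64_t d : in_dims) in_size *= d;  // total element count

  std::vector<int64_t> out(shape.size());
  int64_t capacity = 1;
  int unk_idx = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == -1) {
      assert(unk_idx == -1);  // at most one dimension may be -1
      unk_idx = static_cast<int>(i);
    }
    out[i] = shape[i] ? shape[i] : in_dims[i];  // 0 => copy the input dim
    capacity *= out[i];                         // a -1 entry flips the sign
  }
  if (unk_idx != -1) {
    out[unk_idx] = in_size / (-capacity);           // infer the unknown dim
    assert(out[unk_idx] * (-capacity) == in_size);  // must divide evenly
  } else {
    assert(capacity == in_size);  // fully specified shapes must match exactly
  }
  return out;
}

int main() {
  // in_dims = {2, 3, 4}, shape = {0, -1}: copy dim 0, infer dim 1 => {2, 12}.
  auto out = resolve_shape({0, -1}, {2, 3, 4});
  assert(out[0] == 2 && out[1] == 12);
  return 0;
}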
@@ -107,64 +201,72 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
  }
};

- void ReshapeKernel::operator()(const framework::ExecutionContext &ctx) const {
-   auto *out = ctx.Output<framework::LoDTensor>("Out");
-   auto *in = ctx.Input<framework::LoDTensor>("X");
+ class ReshapeKernel {
+  public:
+   void operator()(const framework::ExecutionContext &ctx) const {
+     auto *out = ctx.Output<framework::LoDTensor>("Out");
+     auto *in = ctx.Input<framework::LoDTensor>("X");

-   auto *shape_tensor = ctx.HasInput("Shape")
-                            ? ctx.Input<framework::LoDTensor>("Shape")
-                            : nullptr;
+     auto *shape_tensor = ctx.HasInput("Shape")
+                              ? ctx.Input<framework::LoDTensor>("Shape")
+                              : nullptr;

-   framework::DDim out_dims = out->dims();
+     framework::DDim out_dims = out->dims();

-   if (shape_tensor) {
-     auto *shape_data = shape_tensor->data<int>();
-     framework::Tensor cpu_shape_tensor;
-     if (platform::is_gpu_place(ctx.GetPlace())) {
-       TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
-       shape_data = cpu_shape_tensor.data<int>();
+     if (shape_tensor) {
+       auto *shape_data = shape_tensor->data<int>();
+       framework::Tensor cpu_shape_tensor;
+       if (platform::is_gpu_place(ctx.GetPlace())) {
+         TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
+         shape_data = cpu_shape_tensor.data<int>();
+       }
+       auto shape =
+           std::vector<int>(shape_data, shape_data + shape_tensor->numel());
+       out_dims = ReshapeOp::ValidateShape(shape, in->dims());
+     }
+     if (!in->lod().empty()) {
+       PADDLE_ENFORCE_EQ(
+           out_dims[0], in->dims()[0],
+           "Reshape operator cannot reshape an input sequence batch "
+           "into an output sequence batch that has a different "
+           "number of time steps. Please consider using "
+           "sequence_reshape op.");
    }
-     auto shape =
-         std::vector<int>(shape_data, shape_data + shape_tensor->numel());
-     out_dims = ReshapeOp::ValidateShape(shape, in->dims());
-   }
-   if (!in->lod().empty()) {
-     PADDLE_ENFORCE_EQ(out_dims[0], in->dims()[0],
-                       "Reshape operator cannot reshape an input sequence batch "
-                       "into an output sequence batch that has a different "
-                       "number of time steps. Please consider using "
-                       "sequence_reshape op.");
-   }

-   bool inplace = ctx.Attr<bool>("inplace");
-   out->Resize(out_dims);
-   if (!inplace) {
-     out->mutable_data(ctx.GetPlace(), in->type());
-     framework::TensorCopySync(*in, ctx.GetPlace(), out);
-     out->Resize(out_dims);
-   } else {
-     out->ShareDataWith(*in);
+     bool inplace = ctx.Attr<bool>("inplace");
    out->Resize(out_dims);
+     if (!inplace) {
+       out->mutable_data(ctx.GetPlace(), in->type());
+       framework::TensorCopySync(*in, ctx.GetPlace(), out);
+       out->Resize(out_dims);
+     } else {
+       out->ShareDataWith(*in);
+       out->Resize(out_dims);
+     }
  }
- }
- void ReshapeGradKernel::operator()(
-     const framework::ExecutionContext &ctx) const {
-   auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-   auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-
-   d_x->mutable_data(ctx.GetPlace(), d_out->type());
-   bool inplace = ctx.Attr<bool>("inplace");
-
-   auto in_dims = d_x->dims();
-   if (!inplace) {
-     framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
-     ctx.device_context().Wait();
-     d_x->Resize(in_dims);
-   } else {
-     d_x->ShareDataWith(*d_out);
-     d_x->Resize(in_dims);
+ };
+
+ class ReshapeGradKernel {
+  public:
+   void operator()(const framework::ExecutionContext &ctx) const {
+     auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+     auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+
+     d_x->mutable_data(ctx.GetPlace(), d_out->type());
+     bool inplace = ctx.Attr<bool>("inplace");
+
+     auto in_dims = d_x->dims();
+     if (!inplace) {
+       framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
+       ctx.device_context().Wait();
+       d_x->Resize(in_dims);
+     } else {
+       d_x->ShareDataWith(*d_out);
+       d_x->Resize(in_dims);
+     }
  }
- }
+ };
+

}  // namespace operators
}  // namespace paddle
namespace ops = paddle::operators;
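The kernel bodies above reduce to a metadata update (Resize) plus either buffer sharing (ShareDataWith, when Attr(inplace) is set) or a full copy (TensorCopySync). A toy sketch of that choice, with a shared_ptr-backed struct standing in for framework::LoDTensor; all names here are illustrative, not Paddle API:

#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

struct Tensor {
  std::shared_ptr<std::vector<float>> buf;  // underlying storage
  std::vector<int64_t> dims;                // shape metadata
};

void Reshape(const Tensor &in, Tensor *out, std::vector<int64_t> out_dims,
             bool inplace) {
  if (inplace) {
    out->buf = in.buf;  // share storage, as ShareDataWith does
  } else {
    // Copy the data, as TensorCopySync does.
    out->buf = std::make_shared<std::vector<float>>(*in.buf);
  }
  out->dims = std::move(out_dims);  // metadata-only change, as Resize does
}

int main() {
  Tensor x{std::make_shared<std::vector<float>>(24, 1.0f), {2, 3, 4}};
  Tensor y;
  Reshape(x, &y, {2, 12}, /*inplace=*/true);
  assert(y.buf.get() == x.buf.get());  // inplace: same underlying storage
  Reshape(x, &y, {2, 12}, /*inplace=*/false);
  assert(y.buf.get() != x.buf.get());  // copy: distinct storage
  return 0;
}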
@@ -179,3 +281,13 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
                               double, ops::ReshapeGradKernel, int,
                               ops::ReshapeGradKernel, int64_t,
                               ops::ReshapeGradKernel);
+
+ #ifdef PADDLE_WITH_CUDA
+ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
+                                 ops::ReshapeKernel, int, ops::ReshapeKernel,
+                                 int64_t, ops::ReshapeKernel);
+ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
+                                 double, ops::ReshapeGradKernel, int,
+                                 ops::ReshapeGradKernel, int64_t,
+                                 ops::ReshapeGradKernel);
+ #endif
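Because reshape never reads or writes element values, one untyped functor body serves every element type; the *_KERNEL_FUNCTOR macros above record the same functor under each (op, dtype) registry key, for CPU and, when built with CUDA, for GPU as well. A minimal sketch of that registration idea, with std::function and a std::map standing in for Paddle's real kernel registry; all names below are illustrative:

#include <functional>
#include <map>
#include <string>
#include <typeindex>
#include <utility>

struct FakeContext {};  // stands in for framework::ExecutionContext

using KernelFn = std::function<void(const FakeContext &)>;
std::map<std::pair<std::string, std::type_index>, KernelFn> g_registry;

struct ReshapeFunctor {
  // One body for all dtypes: only metadata and raw bytes are touched.
  void operator()(const FakeContext &) const { /* resize / share buffer */ }
};

template <typename T>
void RegisterFor(const std::string &op, KernelFn fn) {
  g_registry[{op, std::type_index(typeid(T))}] = std::move(fn);
}

int main() {
  ReshapeFunctor f;
  RegisterFor<float>("reshape", f);   // the same functor is registered
  RegisterFor<double>("reshape", f);  // once per element type, mirroring
  RegisterFor<int>("reshape", f);     // REGISTER_OP_CPU_KERNEL_FUNCTOR
  g_registry.at({"reshape", std::type_index(typeid(float))})(FakeContext{});
  return 0;
}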