Skip to content

Commit 8868525

Browse files
author
chengduo
authored
Refine reshape_grad and transpose_grad (#13074)
* Add intermediate output (XShape)
* Fix flatten/squeeze/unsqueeze
* Considering compatibility issues, we could not fix the origin op
* Follow review comments
* Reset the shape of XShape
1 parent 7dd8adb commit 8868525

File tree

13 files changed

+650
-139
lines changed

13 files changed

+650
-139
lines changed

paddle/fluid/operators/flatten_op.cc

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,116 @@ class FlattenGradOp : public framework::OperatorBase {
157157
}
158158
};
159159

160+
// FIXME(zcd): flatten2 adds an intermediate output(XShape) based on flatten,
161+
// the XShape is used to carry the shape and lod of X which will be used in
162+
// flatten_grad, in this way, the framework can reuse the memory of X
163+
// immediately the flatten2_op is finished.
164+
// Considering compatibility issues, we could not fix flatten2_op
165+
class Flatten2OpInferShape : public FlattenOpInferShape {
166+
public:
167+
void operator()(framework::InferShapeContext *ctx) const override {
168+
FlattenOpInferShape::operator()(ctx);
169+
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
170+
"Output (XShape) of Flatten op should not be null.");
171+
const auto &in_dims = ctx->GetInputDim("X");
172+
std::vector<int64_t> xshape_dims(in_dims.size() + 1);
173+
xshape_dims[0] = 0;
174+
for (int i = 0; i < in_dims.size(); ++i) {
175+
xshape_dims[i + 1] = in_dims[i];
176+
}
177+
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
178+
ctx->ShareLoD("X", "XShape");
179+
}
180+
};
181+
182+
class Flatten2Op : public framework::OperatorBase {
183+
public:
184+
using OperatorBase::OperatorBase;
185+
186+
private:
187+
void RunImpl(const framework::Scope &scope,
188+
const platform::Place &place) const override {
189+
auto &axis = Attr<int>("axis");
190+
auto in_dims =
191+
scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
192+
const auto &out_dims = FlattenOpInferShape::GetOutputShape(axis, in_dims);
193+
194+
framework::AttributeMap attrs;
195+
attrs["shape"] = out_dims;
196+
attrs["inplace"] = false;
197+
// Invoke Reshape Op
198+
auto reshape_op = framework::OpRegistry::CreateOp(
199+
"reshape2", {{"X", {Input("X")}}, {"Shape", {}}},
200+
{{"Out", {Output("Out")}}, {"XShape", {Output("XShape")}}}, attrs);
201+
reshape_op->Run(scope, place);
202+
}
203+
};
204+
205+
class Flatten2OpMaker : public FlattenOpMaker {
206+
public:
207+
void Make() override {
208+
FlattenOpMaker::Make();
209+
AddOutput("XShape",
210+
"XShape is just used to store the shape and lod of X, which will "
211+
"be used in FlattenGradOp.")
212+
.AsIntermediate();
213+
}
214+
};
215+
216+
class Flatten2GradOpMaker : public framework::SingleGradOpDescMaker {
217+
public:
218+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
219+
220+
std::unique_ptr<framework::OpDesc> Apply() const override {
221+
auto *grad_op = new framework::OpDesc();
222+
grad_op->SetType("flatten2_grad");
223+
grad_op->SetInput("XShape", Output("XShape"));
224+
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
225+
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
226+
grad_op->SetAttrMap(Attrs());
227+
return std::unique_ptr<framework::OpDesc>(grad_op);
228+
}
229+
};
230+
231+
class Flatten2GradInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE(context->HasInput("XShape"),
                   "Input(XShape) shouldn't be null.");
    PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) shouldn't be null.");
    // Strip XShape's leading sentinel dimension to recover X's dims.
    auto xshape_dims = context->GetInputDim("XShape");
    auto dx_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
    context->SetOutputDim(framework::GradVarName("X"), dx_dims);
    context->ShareLoD("XShape", framework::GradVarName("X"));
  }
};
244+
245+
class Flatten2GradOp : public framework::OperatorBase {
246+
public:
247+
using OperatorBase::OperatorBase;
248+
249+
private:
250+
void RunImpl(const framework::Scope &scope,
251+
const platform::Place &place) const override {
252+
auto dx_name = Output(framework::GradVarName("X"));
253+
auto dout_name = Input(framework::GradVarName("Out"));
254+
auto xshape_name = Input("XShape");
255+
auto xshape_dims =
256+
scope.FindVar(xshape_name)->Get<framework::LoDTensor>().dims();
257+
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
258+
259+
framework::AttributeMap attrs;
260+
attrs["shape"] = framework::vectorize2int(x_dims);
261+
attrs["inplace"] = false;
262+
263+
auto reshape_op = framework::OpRegistry::CreateOp(
264+
"reshape2", {{"X", {dout_name}}, {"Shape", {}}},
265+
{{"Out", {dx_name}}, {"XShape", {xshape_name}}}, attrs);
266+
reshape_op->Run(scope, place);
267+
}
268+
};
269+
160270
} // namespace operators
161271
} // namespace paddle
162272

@@ -167,3 +277,8 @@ REGISTER_OPERATOR(flatten, ops::FlattenOp, ops::FlattenOpMaker,
167277
ops::FlattenOpInferShape,
168278
paddle::framework::DefaultGradOpDescMaker<true>);
169279
REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp, ops::FlattenGradInferShape);
280+
281+
REGISTER_OPERATOR(flatten2, ops::Flatten2Op, ops::Flatten2OpMaker,
282+
ops::Flatten2OpInferShape, ops::Flatten2GradOpMaker);
283+
REGISTER_OPERATOR(flatten2_grad, ops::Flatten2GradOp,
284+
ops::Flatten2GradInferShape);

paddle/fluid/operators/reshape_op.cc

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,88 @@ class ReshapeGradKernel {
246246
}
247247
};
248248

249+
// FIXME(zcd): reshape2 adds an intermediate output(XShape) based on reshape,
250+
// the XShape is used to carry the shape and lod of X which will be used in
251+
// reshape_grad, in this way, the framework can reuse the memory of X
252+
// immediately the reshape_op is finished.
253+
// Considering compatibility issues, we could not fix reshape_op
254+
class Reshape2Op : public ReshapeOp {
255+
public:
256+
Reshape2Op(const std::string &type, const framework::VariableNameMap &inputs,
257+
const framework::VariableNameMap &outputs,
258+
const framework::AttributeMap &attrs)
259+
: ReshapeOp(type, inputs, outputs, attrs) {}
260+
261+
void InferShape(framework::InferShapeContext *ctx) const override {
262+
ReshapeOp::InferShape(ctx);
263+
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
264+
"Output(XShape) of ReshapeOp should not be null.");
265+
const auto &x_dims = ctx->GetInputDim("X");
266+
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
267+
xshape_dims[0] = 0;
268+
for (int i = 0; i < x_dims.size(); ++i) {
269+
xshape_dims[i + 1] = x_dims[i];
270+
}
271+
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
272+
ctx->ShareLoD("X", /*->*/ "XShape");
273+
}
274+
};
275+
276+
class Reshape2OpMaker : public ReshapeOpMaker {
277+
public:
278+
void Make() override {
279+
ReshapeOpMaker::Make();
280+
AddOutput("XShape",
281+
"XShape is just used to store the shape and lod of X, which will "
282+
"be used in FlattenGradOp.")
283+
.AsIntermediate();
284+
}
285+
};
286+
287+
class Reshape2GradMaker : public framework::SingleGradOpDescMaker {
288+
public:
289+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
290+
291+
std::unique_ptr<framework::OpDesc> Apply() const override {
292+
auto *grad_op = new framework::OpDesc();
293+
grad_op->SetType("reshape2_grad");
294+
grad_op->SetInput("XShape", Output("XShape"));
295+
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
296+
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
297+
grad_op->SetAttrMap(Attrs());
298+
return std::unique_ptr<framework::OpDesc>(grad_op);
299+
}
300+
};
301+
302+
class Reshape2GradOp : public framework::OperatorWithKernel {
303+
public:
304+
Reshape2GradOp(const std::string &type,
305+
const framework::VariableNameMap &inputs,
306+
const framework::VariableNameMap &outputs,
307+
const framework::AttributeMap &attrs)
308+
: OperatorWithKernel(type, inputs, outputs, attrs) {}
309+
310+
void InferShape(framework::InferShapeContext *ctx) const override {
311+
PADDLE_ENFORCE(ctx->HasInput("XShape"), "Input(XShape) shouldn't be null.");
312+
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
313+
"Input(Out@GRAD) shouldn't be null.");
314+
auto xshape_dims = ctx->GetInputDim("XShape");
315+
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
316+
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
317+
ctx->ShareLoD("XShape", framework::GradVarName("X"));
318+
}
319+
320+
protected:
321+
framework::OpKernelType GetExpectedKernelType(
322+
const framework::ExecutionContext &ctx) const override {
323+
return framework::OpKernelType(
324+
framework::ToDataType(
325+
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
326+
->type()),
327+
ctx.device_context());
328+
}
329+
};
330+
249331
} // namespace operators
250332
} // namespace paddle
251333
namespace ops = paddle::operators;
@@ -261,6 +343,17 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
261343
ops::ReshapeGradKernel, int64_t,
262344
ops::ReshapeGradKernel);
263345

346+
REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
347+
ops::Reshape2GradMaker);
348+
REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp);
349+
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
350+
ops::ReshapeKernel, int, ops::ReshapeKernel,
351+
int64_t, ops::ReshapeKernel);
352+
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
353+
double, ops::ReshapeGradKernel, int,
354+
ops::ReshapeGradKernel, int64_t,
355+
ops::ReshapeGradKernel);
356+
264357
#ifdef PADDLE_WITH_CUDA
265358
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
266359
ops::ReshapeKernel, int, ops::ReshapeKernel,
@@ -269,4 +362,11 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
269362
double, ops::ReshapeGradKernel, int,
270363
ops::ReshapeGradKernel, int64_t,
271364
ops::ReshapeGradKernel);
365+
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
366+
ops::ReshapeKernel, int, ops::ReshapeKernel,
367+
int64_t, ops::ReshapeKernel);
368+
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
369+
double, ops::ReshapeGradKernel, int,
370+
ops::ReshapeGradKernel, int64_t,
371+
ops::ReshapeGradKernel);
272372
#endif

paddle/fluid/operators/squeeze_op.cc

Lines changed: 119 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,15 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
126126
.SetDefault({});
127127
AddComment(R"DOC(
128128
Squeeze Operator.
129-
130-
Remove single-dimensional entries from the shape of a tensor.
131-
Takes a parameter axes with a list of axes to squeeze.
132-
If axes is not provided, all the single dimensions will be removed from the shape.
129+
130+
Remove single-dimensional entries from the shape of a tensor.
131+
Takes a parameter axes with a list of axes to squeeze.
132+
If axes is not provided, all the single dimensions will be removed from the shape.
133133
If an axis is selected with shape entry not equal to one, an error is raised.
134-
134+
135135
Examples:
136136
Case 1:
137-
Given
137+
Given
138138
X.shape = (1, 3, 1, 5)
139139
and
140140
axes = [0]
@@ -144,7 +144,7 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
144144
Case 2:
145145
Given
146146
X.shape = (1, 3, 1, 5)
147-
and
147+
and
148148
axes = []
149149
we get:
150150
Out.shape = (3, 5)
@@ -181,6 +181,113 @@ class SqueezeGradOp : public framework::OperatorBase {
181181
}
182182
};
183183

184+
// FIXME(zcd): squeeze2 adds an intermediate output(XShape) based on squeeze,
185+
// the XShape is used to carry the shape and lod of X which will be used in
186+
// squeeze_grad, in this way, the framework can reuse the memory of X
187+
// immediately the squeeze2_op is finished.
188+
// Considering compatibility issues, we could not fix squeeze2_op
189+
class Squeeze2OpMaker : public SqueezeOpMaker {
190+
public:
191+
void Make() override {
192+
SqueezeOpMaker::Make();
193+
AddOutput("XShape",
194+
"XShape is just used to store the shape and lod of X, which will "
195+
"be used in SqueezeGradOp.")
196+
.AsIntermediate();
197+
}
198+
};
199+
200+
class Squeeze2OpInferShape : public SqueezeOpInferShape {
201+
public:
202+
void operator()(framework::InferShapeContext *ctx) const override {
203+
SqueezeOpInferShape::operator()(ctx);
204+
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
205+
"Output(XShape) of Squeeze operator should not be null.");
206+
const auto &x_dims = ctx->GetInputDim("X");
207+
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
208+
xshape_dims[0] = 0;
209+
for (int i = 0; i < x_dims.size(); ++i) {
210+
xshape_dims[i + 1] = x_dims[i];
211+
}
212+
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
213+
ctx->ShareLoD("X", /*->*/ "XShape");
214+
}
215+
};
216+
217+
class Squeeze2Op : public framework::OperatorBase {
218+
public:
219+
using OperatorBase::OperatorBase;
220+
221+
private:
222+
void RunImpl(const framework::Scope &scope,
223+
const platform::Place &place) const override {
224+
auto &axes = Attr<std::vector<int>>("axes");
225+
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
226+
auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims);
227+
228+
framework::AttributeMap attrs;
229+
attrs["shape"] = framework::vectorize2int(out_dims);
230+
// Invoke Reshape Op
231+
auto reshape_op = framework::OpRegistry::CreateOp(
232+
"reshape2", {{"X", {Input("X")}}, {"Shape", {}}},
233+
{{"Out", {Output("Out")}}, {"XShape", {Output("XShape")}}}, attrs);
234+
reshape_op->Run(scope, place);
235+
}
236+
};
237+
238+
class Squeeze2GradOpMaker : public framework::SingleGradOpDescMaker {
239+
public:
240+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
241+
242+
std::unique_ptr<framework::OpDesc> Apply() const override {
243+
auto *grad_op = new framework::OpDesc();
244+
grad_op->SetType("squeeze2_grad");
245+
grad_op->SetInput("XShape", Output("XShape"));
246+
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
247+
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
248+
grad_op->SetAttrMap(Attrs());
249+
return std::unique_ptr<framework::OpDesc>(grad_op);
250+
}
251+
};
252+
253+
class Squeeze2GradInferShape : public framework::InferShapeBase {
254+
public:
255+
void operator()(framework::InferShapeContext *context) const override {
256+
PADDLE_ENFORCE(context->HasInput("XShape"),
257+
"Input(XShape) shouldn't be null.");
258+
PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
259+
"Input(Out@GRAD) shouldn't be null.");
260+
auto xshape_dims = context->GetInputDim("XShape");
261+
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
262+
context->SetOutputDim(framework::GradVarName("X"), x_dims);
263+
context->ShareLoD("XShape", framework::GradVarName("X"));
264+
}
265+
};
266+
267+
class Squeeze2GradOp : public framework::OperatorBase {
268+
public:
269+
using OperatorBase::OperatorBase;
270+
271+
private:
272+
void RunImpl(const framework::Scope &scope,
273+
const platform::Place &place) const override {
274+
auto dx_name = Output(framework::GradVarName("X"));
275+
auto dout_name = Input(framework::GradVarName("Out"));
276+
auto xshape_name = Input("XShape");
277+
auto xshape_dims =
278+
scope.FindVar(xshape_name)->Get<framework::LoDTensor>().dims();
279+
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
280+
281+
framework::AttributeMap attrs;
282+
attrs["shape"] = framework::vectorize2int(x_dims);
283+
284+
auto reshape_op = framework::OpRegistry::CreateOp(
285+
"reshape2", {{"X", {dout_name}}, {"Shape", {}}},
286+
{{"Out", {dx_name}}, {"XShape", {xshape_name}}}, attrs);
287+
reshape_op->Run(scope, place);
288+
}
289+
};
290+
184291
} // namespace operators
185292
} // namespace paddle
186293

@@ -192,3 +299,8 @@ REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
192299
ops::SqueezeOpInferShape,
193300
paddle::framework::DefaultGradOpDescMaker<true>);
194301
REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp, ops::SqueezeGradInferShape);
302+
303+
REGISTER_OPERATOR(squeeze2, ops::Squeeze2Op, ops::Squeeze2OpMaker,
304+
ops::Squeeze2OpInferShape, ops::Squeeze2GradOpMaker);
305+
REGISTER_OPERATOR(squeeze2_grad, ops::Squeeze2GradOp,
306+
ops::Squeeze2GradInferShape);

0 commit comments

Comments
 (0)